from typing import Iterator
import torch
from torch import Tensor as T
from .Level import Level
from ... import utils
class LSSVM(Level):
    r"""
    Least squares support vector machine.

    The model can be solved in closed form, either in primal (weight and bias) or in dual
    (hidden variables and bias), by solving a single linear system.

    :param gamma: Regularization parameter of the LSSVM. Defaults to 1.
    :type gamma: float, optional
    """

    # NOTE: ``__init__`` is documented with comments rather than a docstring because it is
    # wrapped by ``utils.extend_docstring``, which presumably manipulates ``__doc__``.
    @utils.extend_docstring(Level)
    @utils.kwargs_decorator({
        "requires_bias": True
    })
    def __init__(self, *args, **kwargs):
        super(LSSVM, self).__init__(*args, **kwargs)
        # The regularization parameter is held as a torch parameter; 1 is used when the
        # caller does not provide ``gamma``.
        gamma = kwargs.pop('gamma', 1.)
        self._gamma = torch.nn.Parameter(torch.tensor(gamma, dtype=utils.FTYPE))
        self._mse_loss = torch.nn.MSELoss()

    def __str__(self):
        return "LSSVM with " + Level.__str__(self)

    @property
    def gamma(self) -> float:
        r"""
        Regularization parameter, returned as a plain Python number.
        """
        return self._gamma.item()

    @gamma.setter
    def gamma(self, val):
        val = utils.castf(val, dev=self._gamma.device, tensor=False)
        self._gamma.data = val
        # Previously computed solutions depend on gamma, hence both are reset.
        self._reset_dual()
        self._reset_primal()

    def _center_hidden(self) -> None:
        r"""
        Subtracts from the dual parameters their mean along ``dim=1``, if they exist.
        """
        if self._dual_param_exists:
            # NOTE(review): the mean is taken without ``keepdim``, so the reduced axis is
            # dropped before the subtraction. Confirm against the stored layout of
            # ``_dual_param`` that this broadcasts as intended.
            self._dual_param.data -= torch.mean(self._dual_param.data, dim=1)
        else:
            self._logger.debug("The hidden variables cannot be centered as they are not set.")

    def _solve_primal(self) -> None:
        r"""
        Solves the primal problem in closed form and stores the resulting weight and bias.

        The linear system solved is, in block form,
        ``[[C + gamma * I, P^T], [P, N]] @ [W; b] = [Phi^T Y; S]``
        where ``P`` and ``S`` are the sums over the rows of the feature map and of the
        targets respectively, and ``N`` is the number of sample points.
        """
        C = self.kernel.C
        phi = self.kernel.phi_sample
        dev = C.device

        N = torch.tensor([[self.num_sample]],
                         dtype=utils.FTYPE,
                         device=dev)
        P = torch.sum(phi, dim=0, keepdim=True)
        S = torch.sum(self.current_target, dim=0, keepdim=True)
        Y = phi.t() @ self.current_target

        A = torch.cat((torch.cat((C + self._gamma * self._I_primal, P.t()), dim=1),
                       torch.cat((P, N), dim=1)), dim=0)
        B = torch.cat((Y, S), dim=0)

        sol = torch.linalg.solve(A, B)

        # All rows but the last hold the weight, the last row holds the bias.
        self.weight = sol[0:-1].data
        self.bias = sol[-1].data

    def _solve_dual(self) -> None:
        r"""
        Solves the dual problem in closed form and stores the resulting hidden variables
        and bias.

        The linear system solved is, in block form,
        ``[[K + gamma * I, 1], [1^T, 0]] @ [H; b] = [Y; 0]``
        with ``1`` a column of ones of length ``num_sample``.
        """
        K = self.kernel.K
        dev = K.device

        Ones = torch.ones((self.num_sample, 1),
                          dtype=utils.FTYPE,
                          device=dev)
        Zero = torch.zeros((1, 1),
                           dtype=utils.FTYPE,
                           device=dev)
        Zeros = torch.zeros((1, self.dim_output),
                            dtype=utils.FTYPE,
                            device=dev)
        N1 = Ones
        N2 = self.current_target

        A = torch.cat((torch.cat((K + self._gamma * self._I_dual, N1), dim=1),
                       torch.cat((N1.t(), Zero), dim=1)), dim=0)
        B = torch.cat((N2, Zeros), dim=0)

        sol = torch.linalg.solve(A, B)

        # All rows but the last hold the hidden variables, the last row holds the bias.
        hidden = sol[0:-1].data
        bias = sol[-1].data

        self.update_dual(hidden, idx_sample=self.idx)
        self.bias = bias

    def _euclidean_parameters(self, recurse=True) -> Iterator[torch.nn.Parameter]:
        r"""
        Yields the parameters of the parent class, then the parameter of the current
        representation (if it exists) and finally the bias.
        """
        yield from super(LSSVM, self)._euclidean_parameters(recurse)
        if self._representation == 'primal':
            if self._primal_param_exists:
                yield self._weight
        elif self._dual_param_exists:
            yield self._dual_param
        # The bias is yielded in every case, after the representation-specific parameter.
        yield self._bias

    @property
    def H(self) -> torch.Tensor:
        r"""
        Alias for the dual parameters.
        """
        return self.dual_param

    @property
    def W(self) -> torch.Tensor:
        r"""
        Alias for the primal parameters.
        """
        return self.primal_param

    def loss(self, representation=None) -> T:
        r"""
        Total loss: the regularization term divided by ``num_idx`` plus ``gamma`` times the
        mean squared error.

        :param representation: Representation used to compute the loss. Defaults to ``None``,
            which is resolved by ``utils.check_representation`` in the sub-losses.
        :return: Value of the loss.
        """
        fact = 1 / self.num_idx
        # ``self.gamma`` is a plain Python number here, not the underlying parameter.
        return fact * self._loss_regularization(representation) \
            + self.gamma * self._loss_mse(representation)

    def _loss_regularization(self, representation=None):
        r"""
        Regularization term: ``trace(W^T W)`` in primal, ``trace(H^T K H)`` in dual. The value
        is obtained through ``self._get`` with a key depending on the representation.

        :param representation: Representation used to compute the term. Defaults to ``None``.
        """
        representation = utils.check_representation(representation, self._representation, self)
        level_key = "Level_subloss_default_representation" if self._representation == representation \
            else "Level_subloss_representation"

        def fun():
            if representation == 'primal':
                weight = self.weight
                # Equivalent to torch.trace(weight.T @ weight), i.e. the sum of the squared
                # entries. The indices must be 'ij,ij': 'ij,ji' would compute
                # trace(weight @ weight) and fail for a non-square weight.
                return torch.einsum('ij,ij', weight, weight)
            else:
                hidden = self.hidden
                # Equivalent to torch.trace(hidden.T @ self.K @ hidden).
                return torch.einsum('ji,jk,ki', hidden, self.K, hidden)

        return self._get(key='subloss_regularization_' + representation,
                         level_key=level_key, fun=fun)

    def _loss_mse(self, representation=None):
        r"""
        Mean squared error between the forward prediction and the current target. The value
        is obtained through ``self._get`` with a key depending on the representation.

        :param representation: Representation used to compute the term. Defaults to ``None``.
        """
        representation = utils.check_representation(representation, self._representation, self)
        level_key = "Level_subloss_default_representation" if self._representation == representation \
            else "Level_subloss_representation"

        def fun():
            pred = self._forward(representation=representation)
            return self._mse_loss(pred, self.current_target)

        return self._get(key='subloss_mse_' + representation,
                         level_key=level_key, fun=fun)

    def losses(self, representation=None) -> dict:
        r"""
        Dictionary of the individual loss terms as Python numbers, completed by the losses
        reported by the parent class.

        :param representation: Representation used to compute the regularization and MSE
            terms. Defaults to ``None``.
        :return: Dictionary with the keys ``'Regularization'`` and ``'MSE'`` in addition to
            the ones of the parent class.
        """
        # The requested representation is forwarded to the sub-losses instead of being
        # silently ignored.
        return {'Regularization': self._loss_regularization(representation).detach().cpu().item(),
                'MSE': self._loss_mse(representation).detach().cpu().item(),
                **super(LSSVM, self).losses()}

    def after_step(self) -> None:
        r"""
        Runs the parent's ``after_step`` and then centers the hidden variables.
        """
        super(LSSVM, self).after_step()
        self._center_hidden()

    def _update_dual_from_primal(self):
        r"""
        Not implemented for this level.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError