# Source code for kerch.level.single_view.Ridge

from .LSSVM import LSSVM
from ... import utils
import torch


class Ridge(LSSVM):
    r"""
    Ridge regression level: an :class:`LSSVM` without a bias term.

    Both formulations are solved in closed form with a single linear solve:

    * primal: :math:`(C + \gamma^{-1} I)\, W = \Phi^\top Y`, stored in ``self.weight``,
      where :math:`C` is ``self.kernel.C`` and :math:`\Phi` is ``self.kernel.phi_sample``;
    * dual: :math:`(K + \gamma^{-1} I)\, H = Y`, stored in ``self.hidden``,
      where :math:`K` is ``self.kernel.K``.

    NOTE(review): the two solutions describe the same model only if ``kernel.C``
    equals ``phi_sample.T @ phi_sample`` (unnormalized); confirm in the kernel
    implementation.
    """

    @utils.extend_docstring(LSSVM)
    @utils.kwargs_decorator({
        "requires_bias": False
    })
    def __init__(self, *args, **kwargs):
        # A ridge regression has no bias term: force it off before the parent
        # constructor consumes the keyword arguments.
        bias_overwritten = bool(kwargs["requires_bias"])
        if bias_overwritten:
            kwargs["requires_bias"] = False
        super().__init__(*args, **kwargs)
        # Warn only once the parent constructor has run. If the logger is set up
        # by a base-class constructor, accessing ``self._logger`` earlier would
        # raise AttributeError instead of emitting the warning.
        if bias_overwritten:
            self._logger.warning('A ridge regression has no bias term. '
                                 'The bias parameter is overwritten to False.')

    def _solve_primal(self, target=None) -> None:
        r"""
        Solve the primal ridge system and store the solution in ``self.weight``.

        Solves :math:`(C + \gamma^{-1} I)\, W = \Phi^\top Y` with ``torch.linalg.solve``.

        :param target: Targets :math:`Y`; must be matrix-multipliable with
            ``phi_sample.T`` (presumably ``(num_sample, dim_output)`` — confirm).
        """
        C = self.kernel.C
        phi = self.kernel.phi_sample

        # The regularizer lives in feature space (phi.shape[1] columns). Build it
        # with C's dtype/device so adding it to C never promotes or moves data.
        dim_feature = phi.shape[1]
        identity = torch.eye(dim_feature, dtype=C.dtype, device=C.device)

        rhs = phi.t() @ target
        lhs = C + (1 / self._gamma) * identity

        self.weight = torch.linalg.solve(lhs, rhs)

    def _solve_dual(self, target=None) -> None:
        r"""
        Solve the dual ridge system and store the solution in ``self.hidden``.

        Solves :math:`(K + \gamma^{-1} I)\, H = Y` with ``torch.linalg.solve``.

        :param target: Targets :math:`Y`; its leading dimension must match
            ``self.num_sample`` (the size of the identity added to ``K``).
        """
        K = self.kernel.K

        # Regularizer built with K's dtype/device, consistent with the primal solver.
        identity = torch.eye(self.num_sample, dtype=K.dtype, device=K.device)
        lhs = K + (1 / self._gamma) * identity

        self.hidden = torch.linalg.solve(lhs, target)