import torch
from torch import Tensor as T
from .MVLevel import MVLevel
from .._KPCA import _KPCA
from ... import utils
[docs]
class MVKPCA(_KPCA, MVLevel):
r"""
Multi-View Kernel Principal Component Analysis.
"""
@utils.extend_docstring(_KPCA)
@utils.extend_docstring(MVLevel)
def __init__(self, *args, **kwargs):
super(MVKPCA, self).__init__(*args, **kwargs)
def __str__(self):
return "multi-view KPCA(" + MVLevel.__str__(self) + "\n)"
def _project_primal(self, known, to_predict):
phi_known = self.phi(known)
weight_known = self.weights_by_name(list(known.keys()))
weight_predict = self.weights_by_name(to_predict)
Inv = torch.linalg.inv(torch.diag(self.vals) - weight_predict.T @ weight_predict)
return phi_known @ weight_known @ Inv @ weight_predict.T
def _project_dual(self, known, to_predict):
k_known = self.k(known)
K_predict = self.k(to_predict)
sqrt_vals = torch.sqrt(self.vals).unsqueeze(0)
Norm = sqrt_vals.T @ sqrt_vals
Inv = torch.linalg.inv(torch.diag(self.vals) - self.H.T @ K_predict @ self.H)
# Inv = torch.linalg.inv(self.H.T @ K_predict @ self.H / Norm)
# return K_predict @ self.H.T @ Inv @ self.H @ k_known
return k_known @ self.H @ Inv @ self.H.T @ K_predict
def _project(self, known: dict, representation: str):
r"""
Predicts the feature map of the known not specified in the inputs, based on the values specified in the
inputs.
:param known: Dictionary of the inputs where the key is the view identifier (``str`` or ``int``) and the
values the inputs to the known.
:type known: dict
:return:
:rtype: Tensor
"""
representation = utils.check_representation(representation, default=self._representation, cls=self)
# CONSISTENCY
num_points_known = None
to_predict = []
for key, _ in self.named_views:
if key in known:
value = known[key]
# verify consistency of number of datapoints across the various provided inputs for the known.
if value is not None:
if num_points_known is None:
num_points_known = value.shape[0]
else:
assert num_points_known == value.shape[0], \
f"Inconsistent number of datapoints to predict across the " \
f"different known: {num_points_known} and {value.shape[0]}."
else:
to_predict.append(key)
assert num_points_known is not None, 'Nothing to predict.'
# PREDICTION
switcher = {'primal': self._project_primal,
'dual': self._project_dual}
return switcher.get(representation, 'Error with the specified representation')(known, to_predict), to_predict
[docs]
def project(self, known: dict, representation: str = None) -> T:
representation = utils.check_representation(representation, default=self._representation, cls=self)
return self._project(known, representation)[0]
[docs]
@utils.kwargs_decorator(
{"representation": "dual",
"method": "smoother",
"knn": 1,
}
)
def predict(self, known, **kwargs):
representation = utils.check_representation(kwargs["representation"], default=self._representation, cls=self)
transform, to_predict = self._project(known, representation)
method = kwargs["method"]
sol = {}
if representation == 'primal':
dim = 0
for view, name in zip(self.views_by_name(to_predict), to_predict):
view_phi = transform[:, dim:view.dim_feature]
dim = view.dim_feature
if method == 'smoother':
sol[name] = view.kernel.implicit_preimage(view_phi @ view.phi().T, kwargs["knn"])
elif method == 'pinv':
sol[name] = view.kernel.explicit_preimage(view_phi)
else:
raise NotImplementedError
elif representation == 'dual':
for view, name in zip(self.views_by_name(to_predict), to_predict):
if method == 'smoother':
sol[name] = view.kernel.implicit_preimage(transform, kwargs["knn"])
else:
raise NotImplementedError
return sol
def _update_dual_from_primal(self):
self.hidden = sum([v(representation='primal') for v in self.views]) @ torch.diag(1 / self.vals)
# def predict_opt(self, inputs: dict, representation='dual', lr: float = .001, tot_iter: int = 500) -> dict:
# # initiate parameters
# num_predict = None
# to_predict = []
# for key in self.views:
# if key in inputs:
# value = inputs[key]
# # verify consistency of number of datapoints across the various known.
# if num_predict is None:
# num_predict = value.shape[0]
# else:
# assert num_predict == value.shape[0], f"Inconsistent number of datapoints to predict across the " \
# f"different known: {num_predict} and {value.shape[0]}."
# else:
# to_predict.append(key)
#
# # if nothing is given, only one datapoint is predicted
# if num_predict is None:
# num_predict = 1
#
# # initialize the other datapoints to be predicted
# params = torch.nn.ParameterList([])
#
# def init_primal(params):
# for key in to_predict:
# v = self.view(key)
# inputs[key] = torch.nn.Parameter(
# torch.zeros((num_predict, v.dim_input), dtype=utils.FTYPE),
# requires_grad=True)
# params.append(inputs[key])
# return MVKPCA._primal_obj, params
#
# def init_dual(params):
# for key in to_predict:
# v = self.view(key)
# inputs[key] = torch.nn.Parameter(
# torch.zeros((num_predict, v.dim_input), dtype=utils.FTYPE),
# requires_grad=True)
# params.append(inputs[key])
# return MVKPCA._dual_obj, params
#
# switcher = {'primal': init_primal,
# 'dual': init_dual}
# if representation in switcher:
# fun, params = switcher.get(representation)(params)
# else:
# raise NameError('Invalid representation (must be primal or dual)')
#
# # optimize
# bar = trange(tot_iter)
# opt = torch.optim.SGD(params, lr=lr)
# for _ in bar:
# opt.zero_grad()
# loss = fun(self, x=inputs)
# loss.backward(retain_graph=True)
# opt.step()
# bar.set_description(f"{loss:1.2e}")
#
# return inputs