Source code for kerch.kernel.vision.skewed_chi_2

# coding=utf-8
"""
File containing the Skewed Chi Squared kernel class.

@author: HENRI DE PLAEN
@copyright: KU LEUVEN
@license: MIT
@date: May 2022
"""
from typing import Iterator
import torch

from ... import utils
from ..implicit import Implicit, Kernel



@utils.extend_docstring(Kernel)
class SkewedChi2(Implicit):
    r"""
    Skewed Chi Squared kernel. Often used in computer vision.

    .. math::
        k(x,y) = \prod_i \frac{2\sqrt{x_i+p} \sqrt{y_i+p}}{x_i + y_i + 2p}.

    The square roots are taken with :func:`torch.sqrt`, so any component with :math:`x_i + p < 0`
    (or :math:`y_i + p < 0`) produces NaN. The denominator is clamped from below by ``utils.EPS``
    (presumably a small positive constant). As a result, a component where both :math:`x_i + p` and
    :math:`y_i + p` are zero contributes a factor of 0 rather than being undefined.

    :param p: Free parameter :math:`p`, defaults to 0.
    :param p_trainable: `True` if the gradient of :math:`p` is to be computed. If so, a graph is
        computed and :math:`p` can be updated. `False` just leads to a static computation,
        defaults to `False`.
    :type p: float, optional
    :type p_trainable: bool, optional
    """

    def __init__(self, *args, **kwargs):
        # `p` is removed from kwargs before the parent constructor runs, so it is not forwarded.
        # It is kept as a plain float until the torch parameter is created below. The `p` getter
        # handles that intermediate state.
        self._p = kwargs.pop('p', 0.)
        super(SkewedChi2, self).__init__(*args, **kwargs)
        # Unlike `p`, `p_trainable` is popped only after the parent constructor, so the parent
        # still receives it in its kwargs.
        self._p_trainable = kwargs.pop('p_trainable', False)
        self._p = torch.nn.Parameter(
            torch.tensor(self._p, dtype=utils.FTYPE),
            requires_grad=self._p_trainable)

    @property
    def p(self) -> float:
        r"""
        Parameter :math:`p` of the kernel.
        """
        if isinstance(self._p, torch.nn.Parameter):
            # Detached host copy of the parameter value.
            return self._p.data.cpu().numpy().astype(float)
        # Before the parameter exists (during __init__), `_p` is still the raw value.
        return float(self._p)

    @p.setter
    def p(self, val):
        # Changing p invalidates cached kernel values (non-persisting cache only).
        self._reset_cache(reset_persisting=False)
        self._p.data = utils.castf(val, tensor=False, dev=self._p.device)

    def __str__(self):
        return f"Skewed Chi Squared kernel (p: {self.p})."

    @property
    def hparams_variable(self):
        return {'Kernel parameter p': self.p}

    @property
    def hparams_fixed(self):
        return {"Kernel": "Skewed Chi Squared",
                "Trainable p": self._p_trainable,
                **super(SkewedChi2, self).hparams_fixed}

    def _implicit(self, x, y):
        r"""
        Compute the kernel matrix between the rows of ``x`` and the rows of ``y``.

        Assumes ``x`` is (num_x, dim) and ``y`` is (num_y, dim), which is what the transposes
        below imply. TODO confirm against the ``Implicit`` base class.
        Returns a (num_x, num_y) tensor.
        """
        # Broadcast both inputs to (dim, num_x, num_y): one slice per feature dimension.
        x = x.T[:, :, None]
        y = y.T[:, None, :]
        root_product = torch.sqrt(x + self._p) * torch.sqrt(y + self._p)
        # Clamp guards against division by zero when x_i + y_i + 2p vanishes.
        denominator = torch.clamp(x + y + 2 * self._p, min=utils.EPS)
        ratio = 2 * root_product / denominator
        # Multiply the per-dimension factors together (reduces the leading dim axis).
        return torch.prod(ratio, dim=0)

    def _slow_parameters(self, recurse=True) -> Iterator[torch.nn.Parameter]:
        # p is yielded together with the parent's slow parameters.
        yield self._p
        yield from super(SkewedChi2, self)._slow_parameters(recurse)