# Source code for kerch.kernel.time.indicator

# coding=utf-8
"""
File containing the indicator kernel class.

@author: HENRI DE PLAEN
@copyright: KU LEUVEN
@license: MIT
@date: May 2022
"""
from typing import Iterator
from ... import utils
from ..implicit import Implicit, Kernel

import torch



@utils.extend_docstring(Kernel)
class Indicator(Implicit):
    r"""
    Indicator kernel.

    .. math::
        k(x,y) = \left\{
        \begin{array}{ll}
            \gamma & \text{if } |x-y| = 0, \\
            1 & \text{if } 0 < |x-y| \leq p, \\
            0 & \text{otherwise.}
        \end{array}
        \right.

    .. note ::
        If the default value for :math:`\gamma` is used, it is tied to the lag as :math:`\gamma = 2p+1` and is
        refreshed whenever :math:`p` changes (through the ``lag`` setter, or during training if the lag is
        trainable). Assigning ``gamma`` explicitly removes this link.

    .. warning::
        Depending on the choice of :math:`\gamma`, the kernel may not be positive semi-definite. The default value
        however ensures it, as long as the inputs are integers. If they are not, this may get more complicated.

    .. warning::
        For this kind of kernel, the input dimension of the datapoints `dim_input` must be 1.

    :param lag: Lag parameter :math:`p`, defaults to 1.
    :param gamma: Identity value :math:`\gamma` of the kernel. If `None`, the value will be :math:`\gamma = 2p+1`
        to ensure positive semi-definiteness, defaults to `None`.
    :param lag_trainable: `True` if the gradient of the lag :math:`p` is to be computed. If so, a graph is computed
        and the lag can be updated. `False` just leads to a static computation, defaults to `False`.
    :param gamma_trainable: `True` if the gradient of :math:`\gamma` is to be computed. If so, a graph is computed
        and :math:`\gamma` can be updated. `False` just leads to a static computation. Ignored when `gamma` is
        `None`, in which case :math:`\gamma` is tied to the lag :math:`p`, defaults to `False`.
    :type lag: double, optional
    :type gamma: double, optional
    :type lag_trainable: bool, optional
    :type gamma_trainable: bool, optional
    """

    def __init__(self, *args, **kwargs):
        """
        :param lag: bandwidth of the kernel (default 1)
        :param gamma: value on the diagonal (default 2 * lag + 1, which ensures PSD in most cases)
        """
        # The lag is popped before the parent constructor runs; until it is wrapped in a Parameter below,
        # the ``lag`` property returns this plain value.
        self._lag = kwargs.pop('lag', 1)
        super(Indicator, self).__init__(*args, **kwargs)
        assert self._dim_input == 1, "The indicator kernel is only defined for 1-dimensional entries."

        self._lag_trainable = kwargs.pop('lag_trainable', False)
        self._lag = torch.nn.Parameter(
            torch.tensor(self._lag, dtype=utils.FTYPE), requires_grad=self._lag_trainable)

        self._gamma_trainable = kwargs.pop('gamma_trainable', False)
        gamma = kwargs.pop('gamma', None)
        if gamma is None:
            # Default gamma = 2p + 1 is derived from the lag, never trained directly.
            self._link_training = True
            self._gamma = torch.nn.Parameter(2 * self._lag.data + 1, requires_grad=False)
        else:
            self._link_training = False
            self._gamma = torch.nn.Parameter(
                torch.tensor(gamma, dtype=utils.FTYPE), requires_grad=self._gamma_trainable)

    def __str__(self):
        return f"Indicator kernel (lag: {self.lag})"

    @property
    def hparams_variable(self):
        return {'Kernel lag': self.lag, 'Kernel gamma': self.gamma}

    @property
    def lag(self):
        r"""
        Lag :math:`p` of the kernel (as a NumPy value once the kernel is constructed).
        """
        if isinstance(self._lag, torch.nn.Parameter):
            return self._lag.data.cpu().numpy()
        return self._lag

    @lag.setter
    def lag(self, val):
        self._reset_cache(reset_persisting=False)
        self._lag.data = utils.castf(val, tensor=False, dev=self._lag.device)
        if self._link_training:
            # Keep the tied default gamma = 2p + 1 consistent with the new lag
            # (previously it stayed stale unless the lag was trainable).
            self._update_linked_gamma()

    @property
    def lag_trainable(self) -> bool:
        r"""
        Boolean indicating if the lag :math:`p` is trainable.
        """
        return self._lag_trainable

    @lag_trainable.setter
    def lag_trainable(self, val: bool):
        self._lag_trainable = val
        self._lag.requires_grad = self._lag_trainable

    @property
    def hparams_fixed(self):
        return {"Kernel": "Indicator", **super(Indicator, self).hparams_fixed}

    @property
    def gamma(self):
        r"""
        Identity value :math:`\gamma` of the kernel (value taken when :math:`x = y`), as a NumPy value.
        """
        return self._gamma.data.cpu().numpy()

    @gamma.setter
    def gamma(self, val):
        self._reset_cache(reset_persisting=False)
        # An explicitly assigned gamma is no longer tied to the lag, so later lag updates do not overwrite it.
        self._link_training = False
        self._gamma.data = utils.castf(val, tensor=False, dev=self._gamma.device)

    def _update_linked_gamma(self) -> None:
        r"""Set :math:`\gamma = 2p + 1` from the current lag tensor (used while gamma is tied to the lag)."""
        # Use the tensor itself: ``self.lag`` is a NumPy value and cannot be assigned to ``Tensor.data``.
        self._gamma.data = 2 * self._lag.data + 1

    def _implicit(self, x, y):
        # With a trainable lag the optimizer may have moved p in place, so refresh the tied gamma.
        if self._link_training and self._lag_trainable:
            self._update_linked_gamma()

        # Inputs are expected as (num_points, 1) columns; dim_input == 1 is enforced in __init__.
        assert x.dim() == 2 and y.dim() == 2 and x.shape[1] == 1 and y.shape[1] == 1, \
            'Indicator kernel is only defined for 1-dimensional entries.'
        # (n, 1) - (1, m) -> (n, m) pairwise differences. Unlike squeeze(), this keeps a size-1 side intact.
        diff = x - y.T

        # NOTE(review): ``le`` is a comparison and yields no gradient, so the lag receives none from this path.
        output = torch.abs(diff).le(self._lag).type(dtype=utils.FTYPE)
        # Exact coincidences take the identity value gamma instead of 1.
        return torch.where(diff == 0, self._gamma, output)

    def _slow_parameters(self, recurse=True) -> Iterator[torch.nn.Parameter]:
        yield self._lag
        yield self._gamma
        yield from super(Indicator, self)._slow_parameters(recurse)