# Source code for kerch.kernel.distance.distance_squared

from __future__ import annotations
from abc import ABCMeta, abstractmethod
import torch

from ... import utils
from ._distance import _Distance


# Base class for distances built from a squared distance: subclasses implement
# ``_square_dist`` and this class derives from it the plain distance (square root),
# the sigma-scaled variants and a median-based default bandwidth.
# NOTE(review): ``utils.extend_docstring(_Distance)`` presumably merges the parent's
# docstring into this class's ``__doc__``; the class notes are kept as comments so the
# generated docstring is unchanged — confirm against ``utils.extend_docstring``.
@utils.extend_docstring(_Distance)
class DistanceSquared(_Distance, metaclass=ABCMeta):
    def __init__(self, *args, **kwargs):
        # All arguments are forwarded unchanged to ``_Distance``.
        super().__init__(*args, **kwargs)

    def _sample_square_dist(self, destroy=False) -> torch.Tensor:
        r"""
        Squared distances of the current sample against itself.

        The value is obtained through ``self._get`` under the key ``"_kernel_square_dist_sample"``
        (``default_level='total'``, ``force=True``). The ``fun`` callback computes it as
        ``self._square_dist(self.current_sample_projected, self.current_sample_projected)``;
        presumably ``_get`` is the parent's stored-value accessor and only calls ``fun`` when
        nothing is stored yet — TODO confirm.

        :param destroy: Forwarded to ``self._get``; presumably drops the stored value once it has
            been returned. Defaults to ``False``.
        :return: The output of ``_square_dist`` for the current sample against itself.
        """
        return self._get(key="_kernel_square_dist_sample", default_level='total', force=True, destroy=destroy,
                         fun=lambda: self._square_dist(self.current_sample_projected,
                                                       self.current_sample_projected))

    def _determine_sigma(self) -> None:
        r"""
        Assign a heuristic bandwidth when none has been provided.

        If ``self._sigma_defined`` is falsy, sets
        :math:`\sigma = \frac{1}{2}\sqrt{\mathrm{median}(D)}`, where :math:`D` is the result of
        ``_sample_square_dist()``, and logs a warning reporting the chosen value. The computation
        runs under ``torch.no_grad()``. Nothing happens when sigma is already defined.
        """
        if not self._sigma_defined:
            with torch.no_grad():
                # Called with the default ``destroy=False``, presumably so the stored sample
                # distances can still be reused by ``_square_dist_sigma`` afterwards.
                d = self._sample_square_dist()
                # ``torch.median`` without ``dim`` reduces over every entry of ``d``.
                sigma = .5 * torch.sqrt(torch.median(d))
                self.sigma = sigma
                self._logger.warning(f"The kernel bandwidth sigma has not been provided and is assigned by a "
                                     f"heuristic (sigma={self.sigma:.2e}).")

    def _dist_sigma(self, x, y) -> torch.Tensor:
        r"""
        Sigma-scaled distance between ``x`` and ``y``: the square root of ``_square_dist_sigma(x, y)``.

        The squared value is clamped at zero before the square root, so that tiny negative values a
        subclass's ``_square_dist`` may produce (e.g. floating-point round-off) yield ``0`` instead
        of ``NaN``. Non-negative inputs are unaffected.
        """
        return torch.sqrt(torch.clamp(self._square_dist_sigma(x, y), min=0.))

    def _square_dist_sigma(self, x, y) -> torch.Tensor:
        r"""
        Squared distance between ``x`` and ``y`` scaled by ``self._sigma_fact ** 2``.

        When ``x`` and ``y`` are both the very same object as ``self.current_sample_projected``,
        the sample distances are taken from ``_sample_square_dist(destroy=True)`` instead of being
        recomputed; otherwise ``_square_dist(x, y)`` is evaluated.
        """
        # Read only for its side effect: presumably the ``sigma`` property determines the
        # bandwidth lazily (e.g. through ``_determine_sigma``) — TODO confirm.
        _ = self.sigma
        # Identity (not value) comparison; ``current_sample_projected`` is only read when x is y.
        if x is y and x is self.current_sample_projected:
            d = self._sample_square_dist(destroy=True)
        else:
            d = self._square_dist(x, y)
        return self._sigma_fact ** 2 * d

    @abstractmethod
    def _square_dist(self, x, y) -> torch.Tensor:
        r"""
        Squared distance between ``x`` and ``y``; must be implemented by subclasses.
        """

    def _dist(self, x, y) -> torch.Tensor:
        r"""
        Distance between ``x`` and ``y``: the square root of ``_square_dist(x, y)``.

        As in ``_dist_sigma``, the squared value is clamped at zero first so round-off negatives
        give ``0`` rather than ``NaN``; non-negative inputs are unaffected.
        """
        return torch.sqrt(torch.clamp(self._square_dist(x, y), min=0.))