# Source code for kerch.method._knn

# coding=utf-8
import torch
from ..utils import castf, DEFAULT_KERNEL_TYPE
from ..feature.logger import _GLOBAL_LOGGER


@torch.no_grad()
def knn(dists: torch.Tensor, observations: torch.Tensor, num: int = 1) -> torch.Tensor:
    r"""
    For each row of ``dists``, returns the average of the ``num`` observations whose
    corresponding distances are the smallest.

    :param dists: distances (smaller means closer) between each query point and each observation.
    :type dists: torch.Tensor [num_points, num_observations]
    :param observations: observation corresponding to each column of ``dists``.
    :type observations: torch.Tensor [num_observations, dim_observations]
    :param num: number of nearest neighbors. Defaults to 1.
    :type num: int, optional
    :return: KNN
    :rtype: torch.Tensor [num_points, dim_observations]
    :raises ValueError: if ``num`` cannot be converted to an integer.
    :raises AssertionError: if the number of columns of ``dists`` differs from the number of
        observations, or if ``num`` is not in ``[1, num_observations]``.
    """
    # PRELIMINARIES
    dists = castf(dists)
    observations = castf(observations)
    # Unpacking also enforces that ``dists`` is two-dimensional.
    _, num_coefficients = dists.shape
    num_observations = observations.shape[0]

    # DEFENSIVE
    try:
        num = int(num)
    except (TypeError, ValueError) as err:
        # TypeError covers e.g. ``num=None``; both are reported as the same user error.
        raise ValueError('The argument num is not an integer.') from err
    assert num_coefficients == num_observations, \
        f'KNN: Incorrect number of coefficients ({num_coefficients}), ' \
        f'compared to the number of points ({num_observations}).'
    assert num <= num_coefficients, \
        (f"Too much required neighbors ({num}) compared to the number of observations points "
         f"({num_observations}). Please insure that the argument num is not greater than the number of observations "
         f"points.")
    assert num > 0, \
        f"The number of required neighbors num must be strictly positive ({num})."

    # PRE-IMAGE
    # Warn and shift only when negative distances are actually present. The emptiness guard
    # avoids calling ``min`` on an empty tensor (zero query points). Subtracting a constant
    # does not change which entries are the smallest, so the selected neighbors are the same.
    if dists.numel() > 0 and dists.min() < 0:
        _GLOBAL_LOGGER._logger.warning('There are negative distances for kNN. The coefficients are changed.')
        dists = dists - dists.min()

    # ``topk`` returns the largest entries, so negating selects the ``num`` smallest distances per row.
    _, indices = torch.topk(-dists, k=num, dim=1)
    # Advanced indexing gathers the selected observations: [num_points, num, ...].
    kept_sample = observations[indices]
    return torch.mean(kept_sample, dim=1)
@torch.no_grad()
def kernel_knn(domain: torch.Tensor, observations: torch.Tensor, num: int = 1,
               kernel_type: str = DEFAULT_KERNEL_TYPE, **kwargs) -> torch.Tensor:
    r"""
    For each domain point, returns the average of the observations of the ``num`` domain points
    with the greatest kernel values (i.e. the most similar ones). The kernel is defined as in
    :py:func:`kerch.kernel.factory`.

    :param domain: domain corresponding to each observation.
    :type domain: torch.Tensor [num_observations, dim_domain]
    :param observations: observation corresponding to each domain entry.
    :type observations: torch.Tensor [num_observations, dim_observations]
    :param num: number of nearest neighbors. Defaults to 1.
    :type num: int, optional
    :param kernel_type: Type of kernel chosen. For the possible choices, please refer to the
        `Factory Type` column of the :doc:`../kernel/index` documentation. Defaults to
        :py:data:`kerch.DEFAULT_KERNEL_TYPE`.
    :param \**kwargs: Arguments to be passed to the kernel constructor, such as `sample` or
        `sigma`. If an argument is passed that does not exist (e.g. `sigma` to a `linear`
        kernel), it will just be neglected. For the default values, please refer to the
        default values of the requested kernel.
    :type kernel_type: str, optional
    :type \**kwargs: dict, optional
    :return: KNN
    :rtype: torch.Tensor [num_points, dim_observations]
    :raises AssertionError: if ``domain`` and ``observations`` have a different number of rows
        (further checks are performed by :py:func:`knn`).
    """
    domain = castf(domain)
    observations = castf(observations)
    assert domain.shape[0] == observations.shape[0], \
        f"Not the same number of domain {domain.shape[0]} and coefficients points {observations.shape[0]}."

    # Local import, kept as in the original (presumably to avoid a circular import with ..kernel).
    from ..kernel import factory
    k = factory(kernel_type=kernel_type, sample=domain, **kwargs)
    # A larger kernel value means more similar, so the negated kernel matrix acts as a distance
    # for knn. NOTE(review): assumes ``k.K`` is the [num_observations, num_observations] kernel
    # matrix of the sample with itself; knn asserts its column count matches num_observations.
    return knn(dists=-k.K, observations=observations, num=num)