Source code for kerch.method._smoother

# coding=utf-8
from __future__ import annotations

import torch
from typing import Union

from ..utils import DEFAULT_KERNEL_TYPE
from kerch.utils import castf


[docs] @torch.no_grad() def smoother(coefficients: torch.Tensor, observations: torch.Tensor, num: Union[int, str] = 'all') -> torch.Tensor: r""" Returns a weighted sum of the observations by the coefficients. .. math:: \mathtt{out}_{i,j} = \frac{\sum_l\mathtt{weights}_{i,l} * \mathtt{observations}_{l,j}}{\sum_l\mathtt{weights}_{i,l}}. :param coefficients: coefficients used in the smoother. :type coefficients: torch.Tensor [num_points, num_observations] :param observations: observation corresponding to each weight dimension. :type observations: torch.Tensor [num_observations, dim_observations] :param num: Number of closest points to be used. Either an integer representing the number or the string ``'all'``. Defaults to ``'all'``. :type num: int or str, optional :return: Weighted observations :rtype: torch.Tensor [num_points, dim_observations] """ # PRELIMINARIES coefficients = castf(coefficients) observations = castf(observations) num_points, num_coefficients = coefficients.shape num_observations = observations.shape[0] # DEFENSIVE assert num_coefficients == num_observations, \ f'Smoother: Incorrect number of coefficients ({num_coefficients}), ' \ f'compared to the number of points observations points({num_observations}).' if isinstance(num, str): assert num.lower() == 'all', \ f"Only the string value of 'all' is allowed as value for the argument num ({num})." \ f"Please use an integer is you prefer specifying a specific number." num = num_coefficients else: try: num = int(num) except ValueError: raise ValueError('The argument num is not an integer.') assert num > 0, \ f"The argument num ({num}) must be a strictly positive integer (num > 0)." assert num <= num_coefficients, \ f"The argument num ({num}) exceeds the number of observations ({num_observations})." # PRE-IMAGE if num == num_coefficients: kept_coeff = coefficients else: vals, indices = torch.topk(coefficients, k=num, dim=1, largest=True) kept_coeff = torch.zeros_like(coefficients) kept_coeff.scatter_(1, indices, vals) normalized_coeff = kept_coeff / torch.sum(kept_coeff, dim=1, keepdim=True) return normalized_coeff @ observations
[docs] @torch.no_grad() def kernel_smoother(domain: torch.Tensor, observations: torch.Tensor, num: Union[int, str] = 'all', kernel_type: str=DEFAULT_KERNEL_TYPE, **kwargs) -> torch.Tensor: r""" Smoothens the function :math:`f(\mathtt{domain}_{i,:}) = \mathtt{observation}_{i,:}` with weights defined as a kernel on the domain. .. math:: \mathtt{out}_{i,j} = \frac{\sum_l k\left(\mathtt{domain}_{i,:},\mathtt{domain}_{l,:}\right) * \mathtt{observations}_{l,j}}{\sum_l k\left(\mathtt{domain}_{i,:},\mathtt{domain}_{l,:}\right)}. The kernel is defined as in :py:func:`kerch.kernel.factory`. :param domain: domain corresponding to each observation. :type domain: torch.Tensor [num_observations, dim_domain] :param observations: observation corresponding to each domain entry. :type observations: torch.Tensor [num_observations, dim_observations] :param num: Number of closest points to be used. Either an integer representing the number or the string ``'all'``. Defaults to ``'all'``. :type num: int or str, optional :param kernel_type: Type of kernel chosen. For the possible choices, please refer to the `Factory Type` column of the :doc:`../kernel/index` documentation. Defaults to :py:data:`kerch.DEFAULT_KERNEL_TYPE`. :param \**kwargs: Arguments to be passed to the kernel constructor, such as `sample` or `sigma`. If an argument is passed that does not exist (e.g. `sigma` to a `linear` kernel), it will just be neglected. For the default values, please refer to the default values of the requested kernel. :type kernel_type: str, optional :type \**kwargs: dict, optional :return: Smoothed function :math:`f` according to kernel :math:`k`. :rtype: torch.Tensor [num_observations, dim_observations] """ domain = castf(domain) observations = castf(observations) assert domain.shape[0] == observations.shape[0], f"Not the same number of domain {domain.shape[0]} and coefficients points {domain.shape[0]}." from ..kernel import factory k = factory(kernel_type=kernel_type, sample=domain, **kwargs) return smoother(coefficients=k.K, observations=observations, num=num)