Source code for kerch.transform.TransformTree
# coding=utf-8
from typing import Union, List
import kerch
import torch
from torch import Tensor
from torch.nn import Parameter
from .Transform import Transform
from .all import (UnitSphereNormalization,
MinimumCentering,
MeanCentering,
UnitVarianceNormalization,
MinMaxNormalization)
[docs]
@kerch.utils.extend_docstring(Transform)
class TransformTree(Transform):
r"""
Creates a tree of transform for efficient computing, with cache management.
:param explicit: True is the transform are to be computed in the explicit formulation, False instead.
:param sample: Default sample on which the statistics are to be computed.
:param default_transform: Optional default list of transform., defaults to None.
:param diag_fun: Optional function handle to directly compute the diagonal of the implicit formulation to increase
computation speed.
:type explicit: bool
:type sample: Tensor or function handle
:type default_transform: List[str]
:type diag_fun: Function handle
"""
_all_transform = {"normalize": UnitSphereNormalization, # for legacy
"center": MeanCentering, # for legacy
"sphere": UnitSphereNormalization,
"min": MinimumCentering,
"variance": UnitVarianceNormalization,
"standard": [MeanCentering, UnitVarianceNormalization],
"minmax": MinMaxNormalization,
"unit_sphere_normalization": UnitSphereNormalization,
"mean_centering": MeanCentering,
"minimum_centering": MinimumCentering,
"unit_variance_normalization": UnitVarianceNormalization,
"minmax_normalization": MinMaxNormalization,
"standardize": [MeanCentering, UnitVarianceNormalization],
"minmax_rescaling": [MinimumCentering, MinMaxNormalization]}
[docs]
@staticmethod
def beautify_transform(transform: list[str]) -> Union[None, List[Transform]]:
r"""
Creates a list of _Transform classes and removes duplicates.
:param transform: list of the different transform.
:type transform: List[str]
"""
if transform is None:
return None
else:
transform_classes = []
for tr in transform:
try:
if issubclass(tr, Transform):
transform_classes.append(tr)
except TypeError:
new_transform = TransformTree._all_transform.get(
tr, NameError(f"Unrecognized transform key {tr}."))
if isinstance(new_transform, Exception):
raise new_transform
elif isinstance(new_transform, List):
for ntr in new_transform:
transform_classes.append(ntr)
elif issubclass(new_transform, Transform):
transform_classes.append(new_transform)
else:
kerch._GLOBAL_LOGGER._logger.error("Error while creating TransformTree list of transform")
# remove same following elements
previous_item = None
idx = 0
for current_item in transform_classes:
if current_item == previous_item:
transform_classes.pop(idx)
else:
previous_item = current_item
idx += 1
return transform_classes
def __init__(self, explicit: bool, sample, default_transform=None, diag_fun=None, **kwargs):
super(TransformTree, self).__init__(explicit=explicit, name='base', **kwargs)
if default_transform is None:
default_transform = []
self._default_transforms = TransformTree.beautify_transform(default_transform)
# create default tree
node = self
for transform in self._default_transforms:
offspring = transform(explicit=self.explicit, default_path=True, cache_level=self.cache_level)
node.add_offspring(offspring)
node = offspring
self._default_node = node
node.default = True
self._base = sample
self._data_oos = None
self._diag_fun = diag_fun
def __str__(self):
output = "Transforms: \n"
if len(self._default_transforms) == 0:
return output + "\t" + "None (default)"
node = self._default_node
while not isinstance(node, TransformTree):
output += "\t" + str(node)
node = node.parent
return output
@property
def default_transforms(self) -> List:
r"""
Default list of transforms to be applied.
"""
return self._default_transforms
@property
def final_transform(self) -> type(Transform):
r"""
Final transform to be applied, which is the last element of
:py:attr:`~kerch.transform.TransformTree.default_transforms`.
"""
try:
return self._default_transforms[-1]
except IndexError:
return TransformTree
def _get_data(self) -> Union[Tensor, Parameter]:
if callable(self._base):
return self._base()
return self._base
def _explicit_statistics(self, sample):
return None
def _implicit_statistics(self, sample, x=None):
return None
def _explicit_sample(self):
if callable(self._base):
return self._base()
return self._base
def _implicit_sample(self):
if callable(self._base):
return self._base()
return self._base
def _implicit_diag(self, x=None) -> Union[Tensor]:
if self._diag_fun is not None:
return self._diag_fun(x)
if x is None:
return torch.diag(self._implicit_sample())[:, None]
else:
return torch.diag(self._implicit_oos(x, x))[:, None]
@property
def projected_sample(self) -> Tensor:
r"""
Sample after transform. Retrieved from cache if relevant.
"""
return self._default_node.sample
def _explicit_statistics_oos(self, oos, x=None):
pass
def _implicit_statistics_oos(self, oos, x=None):
pass
def _explicit_oos(self, x=None):
if callable(self._base):
return self._base(x)
return x
def _implicit_oos(self, x=None, y=None):
if callable(self._base):
return self._base(x, y)
raise NotImplementedError
def _revert_explicit(self, oos):
return oos
def _revert_implicit(self, oos):
return oos
def _get_tree(self, transform: List[str] = None) -> List[Transform]:
transform = TransformTree.beautify_transform(transform)
if transform is None:
transform = self._default_transforms
tree_path = [self]
for tr_class in transform:
current_tr = tree_path[-1]
if tr_class in current_tr.offspring:
offspring = current_tr.offspring[tr_class]
else:
offspring = tr_class(explicit=self.explicit, cache_level=self.cache_level)
current_tr.add_offspring(offspring)
tree_path.append(offspring)
return tree_path
[docs]
def apply(self, oos=None, x=None, y=None, transform: List[str] = None) -> Tensor:
r"""
Applies the transform to the value to out-of-sample data. Either value is a function handle and you can use
x (explicit) and x, y (explicit) to specify the data. Either directly give a Tensor.
.. warning::
If value is a Tensor, some transform may not work in implicit formulation. For example, the unit sphere
normalization requires k(x,x) for all out-of-sample points. Some combinations may be even more intricate.
:param x: Relevant if using a function handle for value.
:param y: Relevant if using a function handle for value in implicit mode.
:param transform: Transforms to be used. If none are to be used, i.e. getting the raw data back, please
specify [], not None, which will return the default transform used for the sample., defaults to None,
i.e., the default transform.
:type x: Tensor
:type y: Tensor
:type transform: List[str]
"""
tree_path = self._get_tree(transform)
sol = tree_path[-1].oos(x=x, y=y)
self._clean_cache()
return sol
[docs]
def revert(self, value, transform: List[str] = None) -> Tensor:
r"""
Reverts the transform (runs the tree backwards) to the value to out-of-sample data.
:param value: Out-of-sample data.
:param transform: Transforms to be used. If none are to be used, i.e. getting the raw data back, please
specify [], not None, which will return the default transform used for the sample., defaults to None,
i.e., the default transform.
:type value: Tensor
:type transform: List[str]
"""
tree_path = self._get_tree(transform)
for transform in reversed(tree_path):
value = transform._revert(value)
return value