Cache Management

When performing various operations on the Kerch modules, some results may be needed several times. Some of these operations are expensive to compute, so it is preferable to avoid recomputation and to load previously computed results from memory. This motivates a cache manager: the purpose of the kerch.feature.Cache class is to extend the kerch.feature.Module class with one. Two example use cases illustrate its relevance:

  • Kernel Matrix:

    Due to its quadratic complexity, computing the kernel matrix may be very expensive, in particular when it is not computed through explicit feature maps. Suppose the kernel matrix has already been computed, for example in order to be plotted. If one then wants to compute its eigendecomposition for KPCA, the matrix is reloaded from memory and not computed a second time.

  • Data Transformations:

    We consider a big sample dataset that we want to be centered and normalized. When working with out-of-sample datasets, these have to be centered using the same statistics as the sample in order to keep the model consistent. These statistics can be stored and re-used every time computations have to be performed on out-of-sample datasets.

In other words, the cache keeps selected results alive beyond the reach of the garbage collector, with fine-grained control over what is kept.
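The data-transformation use case can be sketched without Kerch at all. The snippet below is a minimal illustration in plain Python, not the package's implementation: the helper names column_means and center are hypothetical and the data are made up. It computes the statistics of the sample once and re-uses them for an out-of-sample point.

```python
from statistics import fmean


def column_means(rows):
    """Mean of each column of a list of equally long rows (hypothetical helper)."""
    return [fmean(col) for col in zip(*rows)]


def center(rows, means):
    """Subtract precomputed column means from every row (hypothetical helper)."""
    return [[x - m for x, m in zip(row, means)] for row in rows]


sample = [[1.0, 10.0], [3.0, 30.0]]   # made-up sample
oos = [[2.0, 25.0]]                   # made-up out-of-sample point

sample_means = column_means(sample)             # computed once, then kept
centered_sample = center(sample, sample_means)
centered_oos = center(oos, sample_means)        # re-uses the sample statistics
```

The out-of-sample point is centered with the means of the sample, not with its own, which is what keeps the model consistent.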

Functionalities

Automatic Getter and Saver

Most of the work is done through the _get() method, which checks in one line whether the result of an operation (a kernel matrix, sample statistics…) has already been computed and, if so, returns it. If not, the result is computed and stored in the cache, ready for further re-use (through the _save() method, which _get() calls automatically when the entry is not yet present in the cache).
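The get-or-compute pattern described above can be sketched as follows. This is a toy stand-in written for illustration, not the code of kerch.feature.Cache: the class ToyCache and its dictionary storage are assumptions, and cache levels are deliberately left out.

```python
class ToyCache:
    """Hypothetical stand-in for the get-or-compute pattern (not Kerch's code)."""

    def __init__(self):
        self._store = {}

    def _save(self, key, fun):
        # Compute the value and keep it for later calls.
        value = fun()
        self._store[key] = value
        return value

    def _get(self, key, fun):
        # Return the cached value if present, otherwise compute and save it.
        if key in self._store:
            return self._store[key]
        return self._save(key, fun)


calls = []


def expensive():
    calls.append(1)              # records each real computation
    return sum(range(1000))


cache = ToyCache()
first = cache._get("total", expensive)    # computed and saved
second = cache._get("total", expensive)   # read back from the store
```

After both accesses, calls contains a single entry: the function ran only once.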

Cache Levels

It would be inefficient to store everything in the cache. Therefore, different cache levels exist, meant for storing different values. Each value is assigned a specific cache level when it is computed for the first time. Each module has an internal cache level (cache_level) that serves as a threshold when new values are computed. If the cache level of the newly computed result exceeds the cache level of the module, the result is not saved: the next time the same computation is required, it is computed again rather than loaded from the cache. We distinguish the following cache levels:

  • "none": the cache is non-existent and everything is computed on the go. This is the lightest for the memory.

  • "light": the cache is very light. For example, only the kernel matrix and statistics of the sample points are saved.

  • "normal": same as light, but the statistics of the out-of-sample points are also saved, not the kernel matrices.

  • "heavy": in addition to the statistics, the final kernel matrices of the out-of-sample points are saved.

  • "total": every step of any computation is saved. This is very heavy on the memory.

The higher the cache level, the more is stored in memory and the fewer recomputations are needed. The module’s cache_level attribute therefore controls a time versus memory trade-off.
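The threshold rule can be illustrated with a few lines of Python. This is a sketch under assumptions: the ordering follows the list above, while the function is_saved and the use of list positions as ranks are hypothetical and do not reflect how Kerch encodes the levels internally.

```python
# Levels ordered from the lightest to the heaviest, as listed above.
LEVELS = ["none", "light", "normal", "heavy", "total"]


def is_saved(value_level, module_level):
    """True if a value tagged with value_level does not exceed module_level."""
    return LEVELS.index(value_level) <= LEVELS.index(module_level)


light_in_normal = is_saved("light", "normal")   # within the threshold
heavy_in_normal = is_saved("heavy", "normal")   # exceeds the threshold
total_in_total = is_saved("total", "total")     # equal levels are kept
```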

Resetter

If the sample changes, for example, most of the cache entries have to be recomputed. The two private methods _reset_cache() and _clean_cache() therefore exist to reset and clean the cache. This is done automatically when necessary. In practice, unless they want to tweak the package’s internal workings, users should not call these methods. To reset the cache explicitly, use the reset() method. The current cache entries can be inspected with the print_cache() method, and the keys of the cached values can be retrieved with the cache_keys() method.

Default Cache Levels

For each value that can be stored, a default cache level is defined in kerch.DEFAULT_CACHE_LEVELS. These values can be changed if the user wants to customize the granularity. A summary follows.
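How a level key might resolve to a level can be sketched with a stand-in mapping. Everything here is an assumption made for illustration: we treat the defaults as a dictionary from key to level, the keys "toy_matrix" and "toy_statistics" are made up, and resolve_level is a hypothetical helper, not a Kerch function.

```python
# Stand-in for kerch.DEFAULT_CACHE_LEVELS: the keys and levels are made up.
TOY_DEFAULT_LEVELS = {"toy_matrix": "heavy", "toy_statistics": "light"}


def resolve_level(level_key=None, default_level="normal", table=TOY_DEFAULT_LEVELS):
    """Level of a cache entry: looked up by level_key, else default_level."""
    if level_key is None:
        return default_level
    return table[level_key]


before = resolve_level("toy_matrix")
TOY_DEFAULT_LEVELS["toy_matrix"] = "light"        # customize the granularity
after = resolve_level("toy_matrix")
fallback = resolve_level(default_level="total")   # no key: the argument is used
```

Lowering the level of an entry in the table makes it eligible for lighter caches, at the cost of more memory.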

Sample-Specific

The Transformation Tree refers to an instance of kerch.transform.TransformTree. It does not hold any data itself, only the successive operations that perform the transformations, i.e., the tree of the transformations. The cache levels inside the tree are described further below.

| Key | Description | 'none' | 'light' | 'normal' | 'heavy' | 'total' |
| "sample_transform" | Sample Transformation Tree | | | | | |

Kernel-Specific

This is relevant for all classes that inherit from kerch.kernel.Kernel. The attributes K, Phi and C are always saved once computed, until the sample changes. The Transformation Trees refer to instances of kerch.transform.TransformTree (one for the explicit and one for the implicit). They do not hold any data themselves, only the successive operations that perform the transformations, i.e., the tree of the transformations. The cache levels inside the trees are described further below.

| Key | Description | 'none' | 'light' | 'normal' | 'heavy' | 'total' |
| "sample_phi" | Explicit feature map of the sample | | | | | |
| "sample_C" | Explicit matrix of the sample | | | | | |
| "sample_K" | Kernel matrix of the sample | | | | | |
| "kernel_explicit_transform" | Explicit Transformation Tree | | | | | |
| "kernel_implicit_transform" | Implicit Transformation Tree | | | | | |

Transformation-Specific

The tree itself contains the transformations. These cache levels determine which values are stored inside a transformation tree, an instance of kerch.transform.TransformTree. The tree can store both the statistics (average, variance, minimum…) required to perform the transformations and the transformed values themselves (centered values, normalized values…). "Default" refers to the default transformation. We refer to the documentation of Data and Kernel Transformations for further information.

| Key | Description | 'none' | 'light' | 'normal' | 'heavy' | 'total' |
| "transform_sample_data_default" | Transformed sample value of the default transformation | | | | | |
| "transform_sample_data_nondefault" | Transformed sample value of another transformation | | | | | |
| "transform_sample_statistics_default" | Transformed sample statistics of the default transformation | | | | | |
| "transform_sample_statistics_nondefault" | Transformed sample statistics of another transformation | | | | | |
| "transform_oos_data_default" | Transformed out-of-sample value of the default transformation | | | | | |
| "transform_oos_data_nondefault" | Transformed out-of-sample value of another transformation | | | | | |
| "transform_oos_statistics_default" | Transformed out-of-sample statistics of the default transformation | | | | | |
| "transform_oos_statistics_nondefault" | Transformed out-of-sample statistics of another transformation | | | | | |

Level-Specific

The output of a level is also saved by default, until the sample or the model parameters change. The default representation is either primal or dual. We refer to the documentation of the Level Module for further information. Many models require an identity matrix to be solved. This matrix can be stored for further use, unless the dimensions change. The different constituents of the loss (regularization term, reconstruction term…), referred to as sublosses, are saved independently from the total loss for monitoring. These are also reset once the before_step() method is called.

| Key | Description | 'none' | 'light' | 'normal' | 'heavy' | 'total' |
| "forward_sample_default_representation" | Output value of the sample in the default representation | | | | | |
| "forward_sample_other_representation" | Output value of the sample in the other representation | | | | | |
| "forward_oos_default_representation" | Output value of an out-of-sample in the default representation | | | | | |
| "forward_oos_other_representation" | Output value of an out-of-sample in the other representation | | | | | |
| "Level_I_default_respresentation" | Identity matrix in the default representation dimension | | | | | |
| "Level_I_other_respresentation" | Identity matrix in the other representation dimension | | | | | |
| "Level_subloss_default_respresentation" | Individual sublosses in the default representation | | | | | |
| "Level_subloss_other_respresentation" | Individual sublosses in the other representation | | | | | |

Abstract Class

class kerch.feature.Cache(*args, **kwargs)

Bases: Module

Parameters:
  • cache_level (str, optional) – Cache level for saving temporary results during execution. The higher the cache level, the more is saved. Defaults to 'normal'. We refer to the Cache Management documentation for further information.

  • logging_level (int, optional) – Logging level for this specific instance. If the value is None, the current default kerch global log level will be used. Defaults to None (default kerch logging level). We refer to the Logging in Kerch documentation for further information.

_apply(fn, recurse=True)

This is the native function of torch.nn.modules.module.Module, used when porting the module, for example to move the data to the GPU or CPU. Overriding it ensures that the cache is also ported.

Note

This method is documented for completeness, but it should never be required to call it directly.

_clean_cache(max_level: str | int | None = None)

Cleans all cache elements above a certain level. This is relevant for cleaning the cache elements that have been forced (see _save()).

Parameters:

max_level (str | int, optional) – all levels above this level will be cleaned, max_level excluded. Defaults to the default cache level.

Note

This method is documented for completeness, but it should never be required to call it directly.

_get(key, fun=None, level_key=None, default_level: str = 'normal', force: bool = False, overwrite: bool = False, persisting=False, destroy=False) -> Any

Retrieves an element from the cache. If the element is not present, it is computed and saved to the cache, provided its level is lower than or equal to the cache level of the module. This condition can be bypassed (see the force argument).

Parameters:
  • key (str) – key of the cache element.

  • fun (function handle) – function to compute the element if not in the cache already.

  • level_key (str, optional) – key referencing the default level to use in kerch.DEFAULT_CACHE_LEVELS. If not specified, the default_level argument is used.

  • default_level (str, optional) – level where to save the cache element. Defaults to ‘normal’.

  • force (bool, optional) – if the value is True, the element will nevertheless be saved whatever level is specified. Defaults to False.

  • persisting (bool, optional) – These values are meant to persist after a cache reset when calling reset() with reset_persisting=False. Defaults to False.

  • destroy (bool, optional) – This destroys the value from the cache after being read/computed. This is meant for short-term memory. Defaults to False.

Returns:

The result of fun()

_remove_from_cache(key: str | List[str]) -> None

Removes one or more specific element(s) from the cache.

Parameters:

key (str | list[str]) – Key(s) of the cache element to be removed.

Note

This method is documented for completeness, but it should never be required to call it directly.

_reset_cache(reset_persisting: bool = True, avoid_classes: list | None = None) -> None

This just resets the cache and makes it empty.

Parameters:
  • reset_persisting (bool, optional) – Persisting elements are meant to survive a cache reset (see _save()). If True, they are reset as well. Defaults to True.

  • avoid_classes (list(type(Cache)), optional) – Classes whose elements must not be reset. Defaults to [].

Note

This method is documented for completeness, but it should never be required to call it directly.

_save(key, fun, level_key=None, default_level: str = 'total', force: bool = False, persisting=False) -> Any

Saves an element in the cache.

Parameters:
  • key (str) – key of the cache element.

  • fun (function handle) – function to compute the element if not in the cache already.

  • level_key (str, optional) – key referencing the default level to use in kerch.DEFAULT_CACHE_LEVELS. If not specified, the default_level argument is used.

  • default_level (str, optional) – level where to save the cache element. Defaults to ‘total’.

  • force (bool, optional) – if the value is True, the element will nevertheless be saved whatever level is specified. Defaults to False.

  • persisting (bool, optional) – These values are meant to persist after a cache reset when calling reset() with reset_persisting=False. Defaults to False.

Returns:

The result of fun()

cache_keys(private: bool = False) -> Iterable[str]

Returns an iterable containing the different cache keys. We refer to the Cache Management documentation for more information.

Parameters:

private (bool, optional) – Some cache elements are private and are not returned unless set to True. Defaults to False.

property cache_level: str

Cache level for saving temporary results during execution. The higher the cache level, the more is saved. Defaults to 'normal' unless set otherwise during instantiation. The different possible values are:

  • "none": the cache is non-existent and everything is computed on the go.

  • "light": the cache is very light. For example, only the kernel matrix and statistics of the sample points are saved.

  • "normal": same as light, but the statistics of the out-of-sample points are also saved.

  • "heavy": in addition to the statistics, the final kernel matrices of the out-of-sample points are saved.

  • "total": every step of any computation is saved.

We refer to the Cache Management documentation for further information.

print_cache(private: bool = False) -> None

Prints the cache content. We refer to the Cache Management documentation for further information.

Parameters:

private (bool, optional) – Some cache elements are private and are not returned unless set to True. Defaults to False.

reset(recurse=False, reset_persisting=True) -> None

Resets the cache to be empty. We refer to the Cache Management documentation for more information.

Parameters:
  • recurse (bool, optional) – If True, resets the cache of this module and also of its potential children. Otherwise, it only resets the cache of this module. Defaults to False.

  • reset_persisting (bool, optional) – Persisting elements are meant to survive a cache reset (see _save()). If True, they are reset as well. Defaults to True.
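The interplay between persisting entries and reset_persisting can be sketched with a toy class. This is an illustration under assumptions, not the behaviour of the package verified line by line: ToyResettableCache, its save() and keys() methods and the stored values are all hypothetical.

```python
class ToyResettableCache:
    """Hypothetical sketch of persisting entries (not Kerch's implementation)."""

    def __init__(self):
        self._store = {}              # key -> (value, persisting flag)

    def save(self, key, value, persisting=False):
        self._store[key] = (value, persisting)

    def keys(self):
        return list(self._store)

    def reset(self, reset_persisting=True):
        if reset_persisting:
            self._store.clear()       # drop everything
        else:
            # keep only the entries flagged as persisting
            self._store = {k: v for k, v in self._store.items() if v[1]}


cache = ToyResettableCache()
cache.save("stats", [0.1, 0.2], persisting=True)
cache.save("K", [[1.0]])
cache.reset(reset_persisting=False)
kept = cache.keys()                   # only the persisting entry remains
cache.reset()
after_full_reset = cache.keys()       # nothing remains
```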

Examples

For a proper usage of the Kerch package, there is no need to manage the cache. For the sake of completeness, we nevertheless provide two examples. The first shows how the cache works on an existing implementation. The second shows how one can manage cache elements oneself.

KPCA

The following example illustrates the working of the cache. We will consider two instances of kerch.level.KPCA, one with a light cache level and another with a total cache level.

import kerch
import torch

torch.manual_seed(0)

sample = torch.randn(5, 3)
oos = torch.randn(2, 3)

kpca_light_cache = kerch.level.KPCA(sample=sample,                  # random sample
                                    dim_output=2,                   # we want an output dimension of 2 (the input is 3)
                                    sample_transform=['min'],       # we want the input to be normalized (based on the statistics of the sample)
                                    kernel_transform=['center'],    # we want the kernel to be centered
                                    cache_level='light')            # a 'light' cache level (only related to the sample)

kpca_total_cache = kerch.level.KPCA(sample=sample,                  # idem
                                    dim_output=2,                   # idem
                                    sample_transform=['min'],       # idem
                                    kernel_transform=['center'],    # idem
                                    cache_level='total')            # a 'total' cache level (saves everything)

If we print the cache now, it holds no computed value: only the kernel transformation tree is listed, and it contains no data yet. The advantage of this package is that it does not perform any unnecessary computation: depending on what is required, it computes only what is strictly necessary.

kpca_light_cache.print_cache()
kernel_implicit_transform [none] : Transforms: 
	Mean centering (default)

KERNEL_IMPLICIT_TRANSFORM:

After solving the model and passing an out-of-sample through it, we can see that the cache is well filled. However, nothing related to the out-of-sample has been saved, even though it has been computed. This is a consequence of the light cache level of the module. Similarly, neither the original non-centered kernel matrix nor the original non-transformed sample has been saved.

kpca_light_cache.solve()
kpca_light_cache.forward(oos)
kpca_light_cache.print_cache()
sample_transform [none] : Transforms: 
	Minimum Centering (default)
K [light] : tensor([[ 0.7424,  0.0142, -0.2422, -0.2459, -0.2685],
        [ 0.0142,  0.6743, -0.2509, -0.1506, -0.2870],
        [-0.2422, -0.2509,  0.6180, -0.1665,  0.0415],
        [-0.2459, -0.1506, -0.1665,  0.7539, -0.1910],
        [-0.2685, -0.2870,  0.0415, -0.1910,  0.7049]])
kernel_implicit_transform [none] : Transforms: 
	Mean centering (default)

SAMPLE_TRANSFORM:
Minimum Centering statistics_sample [light] : tensor([-0.8567, -1.0845, -2.1788])
Minimum Centering data_sample [light] : tensor([[2.3977, 0.7911, 0.0000],
        [1.4251, 0.0000, 0.7802],
        [1.2600, 1.9225, 1.4595],
        [0.4533, 0.4879, 2.3608],
        [0.0000, 2.1851, 1.1076]])

KERNEL_IMPLICIT_TRANSFORM:
Mean centering statistics_sample [light] : (tensor([[0.2794],
        [0.3135],
        [0.3416],
        [0.2737],
        [0.2982]]), tensor(0.3013))

We can optionally reset the cache. Nothing is printed: the cache is empty again.

kpca_light_cache.reset()
kpca_light_cache.print_cache()

We can now have a look at the 'total' version. Again, before anything is required, the cache holds no computed value: nothing has been computed yet.

kpca_total_cache.print_cache()
kernel_implicit_transform [none] : Transforms: 
	Mean centering (default)

KERNEL_IMPLICIT_TRANSFORM:

We now see that everything is saved.

kpca_total_cache.solve()
kpca_total_cache.forward(oos)
kpca_total_cache.print_cache()
sample_transform [none] : Transforms: 
	Minimum Centering (default)
K [light] : tensor([[ 0.7424,  0.0142, -0.2422, -0.2459, -0.2685],
        [ 0.0142,  0.6743, -0.2509, -0.1506, -0.2870],
        [-0.2422, -0.2509,  0.6180, -0.1665,  0.0415],
        [-0.2459, -0.1506, -0.1665,  0.7539, -0.1910],
        [-0.2685, -0.2870,  0.0415, -0.1910,  0.7049]])
kernel_implicit_transform [none] : Transforms: 
	Mean centering (default)
forward_140514613185136_dual [normal] : tensor([[-0.0591,  0.6962],
        [-0.0009,  0.4885]])

SAMPLE_TRANSFORM:
base data_sample [normal] : tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986],
        [ 0.4033,  0.8380, -0.7193],
        [-0.4033, -0.5966,  0.1820],
        [-0.8567,  1.1006, -1.0712]])
base data_oos_140514613185136 [normal] : tensor([[ 0.1227, -0.5663,  0.3731],
        [-0.8920, -1.5091,  0.3704]])
Minimum Centering statistics_sample [light] : tensor([-0.8567, -1.0845, -2.1788])
Minimum Centering data_sample [light] : tensor([[2.3977, 0.7911, 0.0000],
        [1.4251, 0.0000, 0.7802],
        [1.2600, 1.9225, 1.4595],
        [0.4533, 0.4879, 2.3608],
        [0.0000, 2.1851, 1.1076]])
Minimum Centering statistics_oos_140514613185136 [normal] : tensor([-0.8567, -1.0845, -2.1788])
Minimum Centering data_oos_140514613185136 [normal] : tensor([[ 0.9794,  0.5182,  2.5519],
        [-0.0353, -0.4246,  2.5492]])

KERNEL_IMPLICIT_TRANSFORM:
base data_oos_140514611875280 [normal] : tensor([[9.3538e-03, 1.4093e-01, 1.7157e-01, 8.4308e-01, 4.2237e-02],
        [5.2590e-04, 5.1963e-02, 1.0564e-02, 5.4803e-01, 7.9825e-03]])
base data_sample [normal] : tensor([[1.0000, 0.3058, 0.0776, 0.0059, 0.0079],
        [0.3058, 1.0000, 0.1029, 0.1353, 0.0234],
        [0.0776, 0.1029, 1.0000, 0.1476, 0.3801],
        [0.0059, 0.1353, 0.1476, 1.0000, 0.0796],
        [0.0079, 0.0234, 0.3801, 0.0796, 1.0000]])
Mean centering statistics_sample [light] : (tensor([[0.2794],
        [0.3135],
        [0.3416],
        [0.2737],
        [0.2982]]), tensor(0.3013))
Mean centering statistics_oos_140514611875280 [normal] : (tensor([[0.2414],
        [0.1238]]), tensor(0.3013))
Mean centering data_oos_140514611875280 [normal] : tensor([[-0.2102, -0.1127, -0.1102,  0.6292, -0.1961],
        [-0.1014, -0.0841, -0.1536,  0.4518, -0.1127]])

We can optionally reset the cache again. Nothing is printed: the cache is empty again.

kpca_total_cache.reset()
kpca_total_cache.print_cache()

Managing the Cache

In the following example, we show how we can add an element to the cache and recover it when called.

import kerch
import torch
import time

class MyCacheExample(kerch.feature.Cache):
    def __init__(self, *args, **kwargs):
        # extract our own argument before forwarding the remaining ones to the parent
        big_matrix = kwargs.pop('big_matrix')
        super().__init__(*args, **kwargs)
        self.big_matrix = big_matrix

    def _compute_qr(self):
        def qr_fun():
            return torch.linalg.qr(self.big_matrix)
        return self._get(key='qr', fun=qr_fun)

    @property
    def Q(self):
        q, r = self._compute_qr()
        return q

    @property
    def R(self):
        q, r = self._compute_qr()
        return r

# we instantiate our new class
m = torch.randn(200, 100)
my_example = MyCacheExample(big_matrix=m)

# we time our Q property
start = time.time()
my_example.Q
end = time.time()
print('First access: ' + str(end-start), end='\n\n')

# we time it again
start = time.time()
my_example.Q
end = time.time()
print('Second access: ' + str(end-start), end='\n\n')

# we now have a look at our cache
my_example.print_cache()
First access: 0.0011394023895263672

Second access: 5.0067901611328125e-06

qr [normal] : torch.return_types.linalg_qr(
Q=tensor([[-2.2830e-02,  8.1899e-03,  4.3683e-02,  ...,  1.6254e-01,
         -7.8151e-02, -7.4254e-02],
        [ 8.8169e-02, -1.3530e-03, -3.0335e-02,  ...,  2.4953e-02,
         -8.5254e-02,  2.3078e-02],
        [ 4.5048e-03,  9.3206e-02, -6.3898e-02,  ...,  4.8213e-02,
          2.2990e-02, -5.4721e-02],
        ...,
        [-4.9373e-02,  7.2460e-03, -8.7388e-02,  ...,  1.2602e-02,
          5.2047e-02, -2.1568e-01],
        [-6.8546e-02, -1.3272e-01,  7.9209e-02,  ...,  4.2796e-03,
         -4.8783e-02, -2.0896e-02],
        [-1.0969e-02,  8.1132e-02,  7.9921e-02,  ..., -3.9294e-02,
         -5.0511e-05,  1.1770e-02]]),
R=tensor([[-14.0207,  -0.7918,   1.0278,  ...,   1.2658,  -0.4110,  -1.0155],
        [  0.0000, -15.4406,   0.9448,  ...,   0.5908,  -0.5970,   0.3828],
        [  0.0000,   0.0000, -14.0512,  ...,   1.3263,  -1.0769,  -0.2321],
        ...,
        [  0.0000,   0.0000,   0.0000,  ...,   9.7327,  -0.5826,   0.8639],
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   9.3808,  -0.2515],
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000, -10.0514]]))

Inheritance Diagram

kerch.feature.Cache inherits from kerch.feature.Module, which in turn inherits from kerch.feature.Logger and torch.nn.modules.module.Module.