Cache Management
When performing various operations on the Kerch modules, some results may be needed several times.
Some of these operations are expensive to compute, so it is preferable to avoid recomputation and to load the results
from memory when they have already been computed. This is the motivation for a cache manager: the purpose of the
kerch.feature.Cache class is to extend the kerch.feature.Module class with one.
To illustrate its relevance, consider two example use cases:
- Kernel Matrix:
Due to its quadratic complexity, computing the kernel matrix may be very expensive, in particular if not computed through explicit feature maps. Let us suppose that the kernel matrix has already been computed in order to be plotted for example. If one wants to then compute its eigendecomposition for KPCA, the matrix is reloaded from memory and not computed a second time.
- Data Transformations:
We consider a big sample dataset that we want to be centered and normalized. When working with out-of-sample datasets, these have to be centered using the same statistics as the sample in order to keep the model consistent. These statistics can be stored and re-used every time computations have to be performed on out-of-sample datasets.
In other words, the cache is a way to keep results alive beyond the reach of the garbage collector, with fine-grained control over what is kept.
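The second use case can be sketched in plain Python, without Kerch. The class name `CenteringStats` and its methods are hypothetical and serve only to illustrate the idea of computing the sample statistics once and reusing them for out-of-sample points; this is not how Kerch implements it.

```python
from statistics import fmean


class CenteringStats:
    """Hypothetical helper: caches the per-column mean of a sample."""

    def __init__(self, sample):
        self._sample = sample
        self._mean = None          # cached statistic, computed lazily
        self.num_computations = 0  # counts how often the mean was computed

    def mean(self):
        # Compute the column means only on first access, then reuse them.
        if self._mean is None:
            self.num_computations += 1
            self._mean = [fmean(col) for col in zip(*self._sample)]
        return self._mean

    def center(self, points):
        # Sample and out-of-sample points are centered with the SAME statistics.
        mu = self.mean()
        return [[x - m for x, m in zip(row, mu)] for row in points]


sample = [[1.0, 2.0], [3.0, 6.0]]
oos = [[2.0, 5.0]]

stats = CenteringStats(sample)
centered_sample = stats.center(sample)  # computes and caches the mean
centered_oos = stats.center(oos)        # reuses the cached mean
```

The out-of-sample point is centered with the mean of the sample, not with its own mean, and the statistic is computed a single time.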
Functionalities
Automatic Getter and Saver
Most of the work is done through the _get() method, which checks in
one line whether the result of an operation (a kernel matrix, sample statistics…) has already been computed and, if so,
returns it. If not, the result is computed and stored in the cache, ready for further re-use (through the
_save() method, which _get() calls automatically when the entry is
not already present in the cache).
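The get-or-compute pattern described above can be sketched as follows. `TinyCache` is a hypothetical stand-in written for this illustration; it shares the method names `_get` and `_save` with Kerch, but ignores cache levels and every other option of the real methods.

```python
class TinyCache:
    """Hypothetical stand-in for a cache manager (not Kerch's code)."""

    def __init__(self):
        self._store = {}

    def _save(self, key, fun):
        # Compute the value and store it under `key`.
        value = fun()
        self._store[key] = value
        return value

    def _get(self, key, fun):
        # Return the cached value if present, otherwise compute and save it.
        if key in self._store:
            return self._store[key]
        return self._save(key, fun)


calls = []


def expensive():
    calls.append(1)  # record each real computation
    return sum(i * i for i in range(1000))


cache = TinyCache()
first = cache._get("sum_of_squares", expensive)   # computed and stored
second = cache._get("sum_of_squares", expensive)  # read from the cache
```

Both calls return the same value, but `expensive` runs only once.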
Cache Levels
It would be inefficient to store everything in the cache. Therefore, different cache levels exist, meant for
storing different values. Each value is assigned a specific level when computed for the first time. Each module
has an internal cache level (cache_level) that serves as a threshold when new values
are computed. If the cache level assigned to the newly computed result exceeds the cache level of the module,
the result is not saved. The next time the same computation is required, it will thus not be loaded from the
cache, but computed again. We distinguish the following cache levels:
- "none": the cache is non-existent and everything is computed on the go. This is the lightest for the memory.
- "light": the cache is very light. For example, only the kernel matrix and statistics of the sample points are saved.
- "normal": same as light, but the statistics of the out-of-sample points are also saved, not the kernel matrices.
- "heavy": in addition to the statistics, the final kernel matrices of the out-of-sample points are saved.
- "total": every step of any computation is saved. This is very heavy on the memory.
The higher the cache level, the more is stored in memory and the fewer recomputations are needed.
The module’s cache_level attribute therefore controls a time versus
memory trade-off.
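The threshold rule can be written as a comparison on ordered levels. This is a minimal sketch under the assumption that the five levels are ordered as listed above; the names `LEVELS` and `should_save` are hypothetical and not part of Kerch.

```python
# Levels ordered from the lightest to the heaviest on memory.
LEVELS = ["none", "light", "normal", "heavy", "total"]


def should_save(value_level: str, module_level: str) -> bool:
    """Save only if the value's level does not exceed the module's level."""
    return LEVELS.index(value_level) <= LEVELS.index(module_level)


# Which value levels end up stored for a 'light' and for a 'total' module.
saved_light = [lvl for lvl in LEVELS if should_save(lvl, "light")]
saved_total = [lvl for lvl in LEVELS if should_save(lvl, "total")]
```

With a module at level "light", a value assigned to "normal" is recomputed every time, whereas a module at level "total" keeps everything.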
Resetter
If the sample changes, for example, most of the cache entries have to be recomputed. The two
private methods _reset_cache() and _clean_cache() therefore
exist to reset and clean the cache. This is done automatically when necessary. In practice, unless one wants to tweak
the package’s internal workings, these methods should not be called by the user. Users who really want to reset the
cache may use the reset() method. They may also inspect the current cache entries by
calling the print_cache() method, or call the cache_keys()
method to retrieve the keys of the cached values.
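The difference between cleaning and resetting can be illustrated with a toy cache. `LevelledCache` and its methods are hypothetical and only mimic the behaviour described on this page (cleaning removes the entries above a given level, and a reset may spare the entries flagged as persisting); Kerch's internals may differ.

```python
LEVELS = ["none", "light", "normal", "heavy", "total"]


class LevelledCache:
    """Hypothetical sketch; each entry stores (value, level, persisting)."""

    def __init__(self):
        self._store = {}

    def save(self, key, value, level="normal", persisting=False):
        self._store[key] = (value, level, persisting)

    def keys(self):
        return sorted(self._store)

    def reset(self, reset_persisting=True):
        # Drop everything, or keep only the entries flagged as persisting.
        if reset_persisting:
            self._store = {}
        else:
            self._store = {k: e for k, e in self._store.items() if e[2]}

    def clean(self, max_level):
        # Drop the entries whose level is strictly above `max_level`.
        limit = LEVELS.index(max_level)
        self._store = {k: e for k, e in self._store.items()
                       if LEVELS.index(e[1]) <= limit}


cache = LevelledCache()
cache.save("K", "kernel matrix", level="light")
cache.save("oos_stats", "out-of-sample statistics", level="normal")
cache.save("forced", "forced entry", level="heavy")
cache.save("identity", "identity matrix", level="light", persisting=True)

cache.clean("normal")                # removes only the 'heavy' entry
after_clean = cache.keys()
cache.reset(reset_persisting=False)  # keeps the persisting entry
after_soft_reset = cache.keys()
cache.reset()                        # removes everything
after_full_reset = cache.keys()
```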
Default Cache Levels
For each possible value to be stored, a default cache level is defined in kerch.DEFAULT_CACHE_LEVELS.
These values can be changed if the user wants to customize the granularity. A summary follows.
Sample-Specific
The Transformation Tree refers to an instance of kerch.transform.TransformTree. It does not hold any data
itself, only the successive operations that perform the transformations, i.e., the tree of the transformations.
The cache levels inside the tree are described further below.
| Key | Description | | | | | |
|---|---|---|---|---|---|---|
| | Sample Transformation Tree | | | | | |
Kernel-Specific
This is relevant for all classes that inherit from kerch.kernel.Kernel. The attributes K,
Phi and C are always saved once computed until the sample
changes. The Transformation Trees refer to instances of kerch.transform.TransformTree (one for the explicit and one for the implicit).
They do not hold any data themselves, only the successive operations that perform the transformations, i.e., the tree of the transformations.
The cache levels inside the trees are described further below.
| Key | Description | | | | | |
|---|---|---|---|---|---|---|
| | Explicit feature map of the sample | | | | | |
| | Explicit matrix of the sample | | | | | |
| | Kernel matrix of the sample | | | | | |
| | Explicit Transformation Tree | | | | | |
| | Implicit Transformation Tree | | | | | |
Transformation-Specific
The tree itself contains the transformations. These values refer to what is stored inside a
transformation tree instance (kerch.transform.TransformTree). The tree can store both the statistics
(average, variance, minimum…) required to perform the transformations and the transformed values themselves
(centered values, normalized values…). The default refers to the default transformation. We refer to the
documentation of Data and Kernel Transformations for further information.
| Key | Description | | | | | |
|---|---|---|---|---|---|---|
| | Transformed sample value of the default transformation | | | | | |
| | Transformed sample value of another transformation | | | | | |
| | Transformed sample statistics of the default transformation | | | | | |
| | Transformed sample statistics of another transformation | | | | | |
| | Transformed out-of-sample value of the default transformation | | | | | |
| | Transformed out-of-sample value of another transformation | | | | | |
| | Transformed out-of-sample statistics of the default transformation | | | | | |
| | Transformed out-of-sample statistics of another transformation | | | | | |
Level-Specific
The output of a level is also saved by default, until the sample or the model parameters change. The default representation
refers to primal or dual. We refer to the documentation of the Level Module for further information.
Many models require an identity matrix to be solved. This matrix can be stored for further use, as long as the
dimensions do not change. The different constituents of the loss (regularization term, reconstruction term…), referred to as sublosses, are saved
independently from the total loss for monitoring. These are also reset once the before_step()
method is called.
| Key | Description | | | | | |
|---|---|---|---|---|---|---|
| | Output value of the sample in the default representation | | | | | |
| | Output value of the sample in the other representation | | | | | |
| | Output value of an out-of-sample in the default representation | | | | | |
| | Output value of an out-of-sample in the other representation | | | | | |
| | Identity matrix in the default representation dimension | | | | | |
| | Identity matrix in the other representation dimension | | | | | |
| | Individual sublosses in the default representation | | | | | |
| | Individual sublosses in the other representation | | | | | |
Abstract Class
- class kerch.feature.Cache(*args, **kwargs)
Bases: Module
- Parameters:
cache_level (str, optional) – Cache level for saving temporary results during execution. The higher the cache level, the more is saved. Defaults to 'normal'. We refer to the Cache Management documentation for further information.
logging_level (int, optional) – Logging level for this specific instance. If the value is None, the current default kerch global log level will be used. Defaults to None (default kerch logging level). We refer to the Logging in Kerch documentation for further information.
- _apply(fn, recurse=True)
This is the native function of torch.nn.modules.module.Module, used when porting the module. This ensures that the cache is also ported. This is used for example to port the data to the GPU or CPU.
Note
This method is documented for completeness, but it should never be required to call it directly.
- _clean_cache(max_level: str | int | None = None)
Cleans all cache elements above a certain level. This is relevant for cleaning the cache elements that have been forced (see _save()).
- Parameters:
max_level (str | int, optional) – all levels above this level will be cleaned, max_level excluded. Defaults to the default cache level.
Note
This method is documented for completeness, but it should never be required to call it directly.
- _get(key, fun=None, level_key=None, default_level: str = 'normal', force: bool = False, overwrite: bool = False, persisting=False, destroy=False) -> Any
Retrieves an element from the cache. If the element is not present, it is computed and saved to the cache, provided its level is lower than or equal to the default level. This can be overridden with the overwrite argument.
- Parameters:
key (str) – key of the cache element.
fun (function handle) – function to compute the element if not in the cache already.
level_key (str, optional) – key referencing the default level to use in kerch.DEFAULT_CACHE_LEVELS. If not specified, the default_level argument is used.
default_level (str, optional) – level where to save the cache element. Defaults to 'normal'.
force (bool, optional) – if the value is True, the element will nevertheless be saved whatever level is specified. Defaults to False.
persisting (bool, optional) – These values are meant to persist after a cache reset when calling reset() with reset_persisting=False. Defaults to False.
destroy (bool, optional) – This destroys the value from the cache after being read/computed. This is meant for short-term memory. Defaults to False.
- Returns:
The result of fun()
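The interplay of the level check with the force and destroy options can be sketched in plain Python. The function `get` below, its `store` dictionary and its `module_level` parameter are hypothetical and follow only the parameter descriptions above; this is not the Kerch implementation, and the overwrite and persisting options are left out.

```python
LEVELS = ["none", "light", "normal", "heavy", "total"]


def get(store, key, fun, level="normal", module_level="normal",
        force=False, destroy=False):
    """Hypothetical get-or-compute with `force` and `destroy` flags."""
    if key in store:
        # pop() removes the entry after reading it when `destroy` is set.
        return store.pop(key) if destroy else store[key]
    value = fun()
    allowed = LEVELS.index(level) <= LEVELS.index(module_level)
    # A forced value is stored whatever its level; a destroyed one never is.
    if (allowed or force) and not destroy:
        store[key] = value
    return value


store = {}
v1 = get(store, "a", lambda: 1, level="total")              # too heavy: not stored
stored_without_force = "a" in store
v2 = get(store, "a", lambda: 2, level="total", force=True)  # stored anyway
stored_with_force = "a" in store
v3 = get(store, "a", lambda: 3, destroy=True)               # read once, then removed
stored_after_destroy = "a" in store
```

The third call returns the value stored by the second one rather than recomputing it, and leaves the cache empty afterwards.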
- _remove_from_cache(key: str | List[str]) -> None
Removes one or more specific element(s) from the cache.
Note
This method is documented for completeness, but it should never be required to call it directly.
- _reset_cache(reset_persisting: bool = True, avoid_classes: list | None = None) -> None
Resets the cache and makes it empty.
- Parameters:
Note
This method is documented for completeness, but it should never be required to call it directly.
- _save(key, fun, level_key=None, default_level: str = 'total', force: bool = False, persisting=False) -> Any
Saves an element in the cache.
- Parameters:
key (str) – key of the cache element.
fun (function handle) – function to compute the element if not in the cache already.
level_key (str, optional) – key referencing the default level to use in kerch.DEFAULT_CACHE_LEVELS. If not specified, the default_level argument is used.
default_level (str, optional) – level where to save the cache element. Defaults to 'total'.
force (bool, optional) – if the value is True, the element will nevertheless be saved whatever level is specified. Defaults to False.
persisting (bool, optional) – These values are meant to persist after a cache reset when calling reset() with reset_persisting=False. Defaults to False.
- Returns:
The result of fun()
- cache_keys(private: bool = False) -> Iterable[str]
Returns an iterable containing the different cache keys. We refer to the Cache Management documentation for more information.
- Parameters:
private (bool, optional) – Some cache elements are private and are not returned unless set to True. Defaults to False.
- property cache_level: str
Cache level for saving temporary results during execution. The higher the cache level, the more is saved. Defaults to 'normal' unless set otherwise during instantiation. The different possible values are:
- "none": the cache is non-existent and everything is computed on the go.
- "light": the cache is very light. For example, only the kernel matrix and statistics of the sample points are saved.
- "normal": same as light, but the statistics of the out-of-sample points are also saved.
- "heavy": in addition to the statistics, the final kernel matrices of the out-of-sample points are saved.
- "total": every step of any computation is saved.
We refer to the Cache Management documentation for further information.
- print_cache(private: bool = False) -> None
Prints the cache content. We refer to the Cache Management documentation for further information.
- Parameters:
private (bool, optional) – Some cache elements are private and are not printed unless set to True. Defaults to False.
- reset(recurse=False, reset_persisting=True) -> None
Resets the cache to be empty. We refer to the Cache Management documentation for more information.
- Parameters:
recurse (bool, optional) – If True, resets the cache of this module and also of its potential children. Otherwise, it only resets the cache of this module. Defaults to False.
reset_persisting (bool, optional) – Persisting elements are meant to survive a cache reset (see _save()). If True, they are reset as well. Defaults to True.
Examples
For proper usage of the Kerch package, there is no need to manage the cache. For the sake of completeness, however, we provide two examples. The first shows how the cache works in an existing implementation. The second shows how one can manage cache elements oneself.
KPCA
The following example illustrates how the cache works. We consider two instances of
kerch.level.KPCA, one with a light cache level and another with a total cache level.
import kerch
import torch
torch.manual_seed(0)
sample = torch.randn(5, 3)
oos = torch.randn(2, 3)
kpca_light_cache = kerch.level.KPCA(sample=sample,                # random sample
                                    dim_output=2,                 # we want an output dimension of 2 (the input is 3)
                                    sample_transform=['min'],     # we want the input to be min-centered (based on the statistics of the sample)
                                    kernel_transform=['center'],  # we want the kernel to be centered
                                    cache_level='light')          # a 'light' cache level (only related to the sample)
kpca_total_cache = kerch.level.KPCA(sample=sample,                # idem
                                    dim_output=2,                 # idem
                                    sample_transform=['min'],     # idem
                                    kernel_transform=['center'],  # idem
                                    cache_level='total')          # a 'total' cache level (saves everything)
If we print the cache now, no computed value appears: apart from the still empty kernel transformation tree, the cache is empty. The advantage of this package is that it does not perform any unnecessary computations: depending on what is requested, it computes only what is strictly necessary.
kpca_light_cache.print_cache()
kernel_implicit_transform [none] : Transforms:
Mean centering (default)
KERNEL_IMPLICIT_TRANSFORM:
After solving the model and passing an out-of-sample through it, we can see that the cache is well
filled. Nothing related to the out-of-sample has been saved, however, even though it has been computed. This is a
consequence of the light cache level of the module. Similarly, neither the original non-centered kernel matrix nor
the original non-transformed sample has been saved.
kpca_light_cache.solve()
kpca_light_cache.forward(oos)
kpca_light_cache.print_cache()
sample_transform [none] : Transforms:
Minimum Centering (default)
K [light] : tensor([[ 0.7424, 0.0142, -0.2422, -0.2459, -0.2685],
[ 0.0142, 0.6743, -0.2509, -0.1506, -0.2870],
[-0.2422, -0.2509, 0.6180, -0.1665, 0.0415],
[-0.2459, -0.1506, -0.1665, 0.7539, -0.1910],
[-0.2685, -0.2870, 0.0415, -0.1910, 0.7049]])
kernel_implicit_transform [none] : Transforms:
Mean centering (default)
SAMPLE_TRANSFORM:
Minimum Centering statistics_sample [light] : tensor([-0.8567, -1.0845, -2.1788])
Minimum Centering data_sample [light] : tensor([[2.3977, 0.7911, 0.0000],
[1.4251, 0.0000, 0.7802],
[1.2600, 1.9225, 1.4595],
[0.4533, 0.4879, 2.3608],
[0.0000, 2.1851, 1.1076]])
KERNEL_IMPLICIT_TRANSFORM:
Mean centering statistics_sample [light] : (tensor([[0.2794],
[0.3135],
[0.3416],
[0.2737],
[0.2982]]), tensor(0.3013))
We can optionally reset the cache. Nothing is printed: the cache is empty again.
kpca_light_cache.reset()
kpca_light_cache.print_cache()
We can now have a look at the 'total' version. Again, before anything is requested, the cache holds no computed values: nothing
has been computed yet.
kpca_total_cache.print_cache()
kernel_implicit_transform [none] : Transforms:
Mean centering (default)
KERNEL_IMPLICIT_TRANSFORM:
We now see that everything is saved.
kpca_total_cache.solve()
kpca_total_cache.forward(oos)
kpca_total_cache.print_cache()
sample_transform [none] : Transforms:
Minimum Centering (default)
K [light] : tensor([[ 0.7424, 0.0142, -0.2422, -0.2459, -0.2685],
[ 0.0142, 0.6743, -0.2509, -0.1506, -0.2870],
[-0.2422, -0.2509, 0.6180, -0.1665, 0.0415],
[-0.2459, -0.1506, -0.1665, 0.7539, -0.1910],
[-0.2685, -0.2870, 0.0415, -0.1910, 0.7049]])
kernel_implicit_transform [none] : Transforms:
Mean centering (default)
forward_140514613185136_dual [normal] : tensor([[-0.0591, 0.6962],
[-0.0009, 0.4885]])
SAMPLE_TRANSFORM:
base data_sample [normal] : tensor([[ 1.5410, -0.2934, -2.1788],
[ 0.5684, -1.0845, -1.3986],
[ 0.4033, 0.8380, -0.7193],
[-0.4033, -0.5966, 0.1820],
[-0.8567, 1.1006, -1.0712]])
base data_oos_140514613185136 [normal] : tensor([[ 0.1227, -0.5663, 0.3731],
[-0.8920, -1.5091, 0.3704]])
Minimum Centering statistics_sample [light] : tensor([-0.8567, -1.0845, -2.1788])
Minimum Centering data_sample [light] : tensor([[2.3977, 0.7911, 0.0000],
[1.4251, 0.0000, 0.7802],
[1.2600, 1.9225, 1.4595],
[0.4533, 0.4879, 2.3608],
[0.0000, 2.1851, 1.1076]])
Minimum Centering statistics_oos_140514613185136 [normal] : tensor([-0.8567, -1.0845, -2.1788])
Minimum Centering data_oos_140514613185136 [normal] : tensor([[ 0.9794, 0.5182, 2.5519],
[-0.0353, -0.4246, 2.5492]])
KERNEL_IMPLICIT_TRANSFORM:
base data_oos_140514611875280 [normal] : tensor([[9.3538e-03, 1.4093e-01, 1.7157e-01, 8.4308e-01, 4.2237e-02],
[5.2590e-04, 5.1963e-02, 1.0564e-02, 5.4803e-01, 7.9825e-03]])
base data_sample [normal] : tensor([[1.0000, 0.3058, 0.0776, 0.0059, 0.0079],
[0.3058, 1.0000, 0.1029, 0.1353, 0.0234],
[0.0776, 0.1029, 1.0000, 0.1476, 0.3801],
[0.0059, 0.1353, 0.1476, 1.0000, 0.0796],
[0.0079, 0.0234, 0.3801, 0.0796, 1.0000]])
Mean centering statistics_sample [light] : (tensor([[0.2794],
[0.3135],
[0.3416],
[0.2737],
[0.2982]]), tensor(0.3013))
Mean centering statistics_oos_140514611875280 [normal] : (tensor([[0.2414],
[0.1238]]), tensor(0.3013))
Mean centering data_oos_140514611875280 [normal] : tensor([[-0.2102, -0.1127, -0.1102, 0.6292, -0.1961],
[-0.1014, -0.0841, -0.1536, 0.4518, -0.1127]])
We can optionally reset the cache again. Nothing is printed: the cache is empty again.
kpca_total_cache.reset()
kpca_total_cache.print_cache()
Managing the Cache
In the following example, we show how we can add an element to the cache and recover it when called.
import kerch
import torch
import time
class MyCacheExample(kerch.feature.Cache):
    def __init__(self, *args, **kwargs):
        super(MyCacheExample, self).__init__(*args, **kwargs)
        self.big_matrix = kwargs.pop('big_matrix')

    def _compute_qr(self):
        def qr_fun():
            return torch.linalg.qr(self.big_matrix)
        return self._get(key='qr', fun=qr_fun)

    @property
    def Q(self):
        q, r = self._compute_qr()
        return q

    @property
    def R(self):
        q, r = self._compute_qr()
        return r
# we instantiate our new class
m = torch.randn(200, 100)
my_example = MyCacheExample(big_matrix=m)
# we time our Q property
start = time.time()
my_example.Q
end = time.time()
print('First access: ' + str(end-start), end='\n\n')
# we time it again
start = time.time()
my_example.Q
end = time.time()
print('Second access: ' + str(end-start), end='\n\n')
# we now have a look at our cache
my_example.print_cache()
First access: 0.0011394023895263672
Second access: 5.0067901611328125e-06
qr [normal] : torch.return_types.linalg_qr(
Q=tensor([[-2.2830e-02, 8.1899e-03, 4.3683e-02, ..., 1.6254e-01,
-7.8151e-02, -7.4254e-02],
[ 8.8169e-02, -1.3530e-03, -3.0335e-02, ..., 2.4953e-02,
-8.5254e-02, 2.3078e-02],
[ 4.5048e-03, 9.3206e-02, -6.3898e-02, ..., 4.8213e-02,
2.2990e-02, -5.4721e-02],
...,
[-4.9373e-02, 7.2460e-03, -8.7388e-02, ..., 1.2602e-02,
5.2047e-02, -2.1568e-01],
[-6.8546e-02, -1.3272e-01, 7.9209e-02, ..., 4.2796e-03,
-4.8783e-02, -2.0896e-02],
[-1.0969e-02, 8.1132e-02, 7.9921e-02, ..., -3.9294e-02,
-5.0511e-05, 1.1770e-02]]),
R=tensor([[-14.0207, -0.7918, 1.0278, ..., 1.2658, -0.4110, -1.0155],
[ 0.0000, -15.4406, 0.9448, ..., 0.5908, -0.5970, 0.3828],
[ 0.0000, 0.0000, -14.0512, ..., 1.3263, -1.0769, -0.2321],
...,
[ 0.0000, 0.0000, 0.0000, ..., 9.7327, -0.5826, 0.8639],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 9.3808, -0.2515],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, -10.0514]]))