# coding=utf-8
"""
File containing the implicit kernel class.
@author: HENRI DE PLAEN
@copyright: KU LEUVEN
@license: MIT
@date: March 2021
"""
from typing import Iterator
from ... import utils
from ..implicit import Implicit, Kernel
import torch
[docs]
@utils.extend_docstring(Kernel)
class ImplicitNN(Implicit):
r"""
Implicit kernel class, parametrized by a neural network.
.. math::
k(x,y) = NN\left( [x, y] \right).
.. warning::
This kernel is not positive semi-definite in the general case. This is only possible if a specific choice of
neural network is provided.
:param network: Network to be used.
:type network: torch.nn.Module
"""
@utils.kwargs_decorator(
{"network": None})
def __init__(self, *args, **kwargs):
"""
:param network: torch.nn.Module explicit kernel
"""
super(ImplicitNN, self).__init__(*args, **kwargs)
self._network: torch.nn.Module = kwargs["network"]
assert isinstance(self._network, torch.nn.Module), "Torch network level must be specified."
def __str__(self):
return "implicit kernel"
@property
def hparams_fixed(self):
return {"Kernel": "Implicit NN", **super(ImplicitNN, self).hparams_fixed}
def _implicit(self, x, y):
raise NotImplementedError
# x, y = super(ImplicitKernel, self)._implicit(x, y)
# return self._encoder(x, y)
def _euclidean_parameters(self, recurse=True) -> Iterator[torch.nn.Parameter]:
yield from self._network.parameters()
yield from super(ImplicitNN, self)._euclidean_parameters(recurse)