# Source code for kerch.utils.type

# coding=utf-8
import torch
import logging

# Package-wide default floating-point dtype; rebound at runtime by ``set_ftype``.
FTYPE: torch.dtype = torch.float32
# Package-wide default integer dtype; rebound at runtime by ``set_itype``.
ITYPE: torch.dtype = torch.int32

def gpu_available() -> bool:
    r"""
    Returns whether GPU-enhanced computation is possible and configured on this machine.

    When a device is available, the backend version is logged through the package logger.

    :return: ``True`` if :func:`torch.cuda.is_available` reports a usable device, ``False`` otherwise.
    :rtype: bool
    """
    if not torch.cuda.is_available():
        return False

    # Imported lazily: the logger is only needed on the GPU path.
    from ..feature.logger import _GLOBAL_LOGGER

    # ``torch.version.cuda`` is ``None`` on ROCm (HIP) builds, where
    # ``torch.cuda.is_available()`` can still return True; concatenating it
    # directly onto a string would raise TypeError there.
    cuda_version = torch.version.cuda
    if cuda_version is not None:
        message = "Using CUDA version " + cuda_version
    else:
        message = "Using ROCm/HIP version " + str(getattr(torch.version, "hip", None))
    _GLOBAL_LOGGER._logger.info(message)
    return True
def set_ftype(type: torch.dtype) -> None:
    r"""
    Sets the generic floating type :attr:`kerch.FTYPE` used throughout the package. Typical choices include
    half precision :attr:`torch.float16`, single precision :attr:`torch.float32` (default) and double precision
    :attr:`torch.float64`.

    :param type: Default floating type to be used.
    :type type: torch.dtype
    :raises AssertionError: If ``type`` is not an instance of :class:`torch.dtype`.

    .. warning::
        This does not affect the already instantiated tensors. It is thus preferable to set this in the beginning
        of the code to avoid any type mismatch.
    """
    # Explicit check instead of a bare ``assert`` so the validation still runs
    # under ``python -O``; AssertionError is kept so callers see the same
    # exception type as before.
    if not isinstance(type, torch.dtype):
        raise AssertionError('The type is not an instance of torch.dtype.')

    # NOTE(review): this rebinds the global of this module only; modules that did
    # ``from ... import FTYPE`` keep their old binding — confirm how kerch.FTYPE is exposed.
    global FTYPE
    FTYPE = type

    # A named logger avoids ``logging.warning``'s implicit ``basicConfig()`` call,
    # which would install a handler on the application's root logger.
    logging.getLogger(__name__).warning('Changing type has to be carefully considered as changes '
                                        'after initialization may lead to inconsistencies.')
def set_itype(type: torch.dtype) -> None:
    r"""
    Sets the generic integer type :attr:`kerch.ITYPE` used throughout the package. Typical choices include
    short integers :attr:`torch.int16` (-32 768 to 32 767), classical integers :attr:`torch.int32`
    (-2^31 to 2^31-1, default) and long integers :attr:`torch.int64` (-2^63 to 2^63-1). We do not advise
    on using unsigned integers because of their limited support in PyTorch.

    :param type: Default integer type to be used.
    :type type: torch.dtype
    :raises AssertionError: If ``type`` is not an instance of :class:`torch.dtype`.

    .. warning::
        This does not affect the already instantiated tensors. It is thus preferable to set this in the beginning
        of the code to avoid any type mismatch.
    """
    # Explicit check instead of a bare ``assert`` so the validation still runs
    # under ``python -O``; AssertionError is kept so callers see the same
    # exception type as before.
    if not isinstance(type, torch.dtype):
        raise AssertionError('The type is not an instance of torch.dtype.')

    # NOTE(review): this rebinds the global of this module only; modules that did
    # ``from ... import ITYPE`` keep their old binding — confirm how kerch.ITYPE is exposed.
    global ITYPE
    ITYPE = type

    # A named logger avoids ``logging.warning``'s implicit ``basicConfig()`` call,
    # which would install a handler on the application's root logger.
    logging.getLogger(__name__).warning('Changing type has to be carefully considered as changes '
                                        'after initialization may lead to inconsistencies.')
def set_eps(eps: float) -> torch.Tensor:
    r"""
    Sets the generic epsilon value used throughout the toolbox to guarantee stability.

    :param eps: Default epsilon value to be used. Must be non-negative.
    :type eps: float
    :return: The new epsilon, stored as a 0-dimensional tensor of the current default floating type
        :attr:`FTYPE`.
    :rtype: torch.Tensor
    :raises AssertionError: If ``eps`` is negative (or NaN).

    .. warning::
        It is preferable to set this in the beginning of the code to avoid any type mismatch, preferably after
        setting the data type: the dtype of the stored tensor is taken from :attr:`FTYPE` at call time.
    """
    # ``not eps >= 0`` (rather than ``eps < 0``) also rejects NaN, matching the
    # original assertion. Raised explicitly so it survives ``python -O``; the
    # AssertionError type is kept for existing callers.
    if not eps >= 0:
        raise AssertionError('The EPS value has to be non-negative')
    global EPS
    EPS = torch.tensor(eps, dtype=FTYPE)
    return EPS
# Module-level default epsilon, built through ``set_eps`` so it is stored as a
# tensor of the default floating type.
EPS = set_eps(eps=1e-7)