Source code for kerch.utils.cast

# coding=utf-8
from __future__ import annotations
import torch
from .type import ITYPE, FTYPE
from .errors import RepresentationError

def castf(x, dev=None, tensor=True) -> torch.Tensor | None:
    r"""
    Casts the input to a PyTorch float tensor of type :attr:`kerch.FTYPE`.

    Non-tensor inputs (scalars, sequences, NumPy arrays, ...) are converted with :func:`torch.tensor`
    (without gradient tracking). Existing tensors are converted with :meth:`torch.Tensor.type`. `None`
    values are not cast and are returned as is.

    :param x: The input to cast.
    :param dev: The device to cast the tensor to. Defaults to `None`, in which case existing tensors stay on
        their current device and newly created tensors use PyTorch's default device.
    :param tensor: If True (default), the result is promoted to a 2D tensor. A scalar becomes shape (1, 1),
        a 1D input of length n becomes (n, 1), and a 2D input keeps its shape. Inputs with more than 2
        dimensions raise an error. If False, the shape is left untouched and no dimension check is performed.
    :return: The input cast to a PyTorch float tensor, with optional device.
    :raises NameError: If `tensor` is True and the data has more than 2 dimensions.
    :type x: float | torch.Tensor | np.ndarray | None
    :type dev: Optional[torch.device]
    :type tensor: bool
    :rtype: torch.Tensor | None
    """
    if x is None:
        return None

    if torch.is_tensor(x):
        x = x.type(FTYPE)
        if dev is not None:
            x = x.to(dev)
    else:
        x = torch.tensor(x, requires_grad=False, dtype=FTYPE, device=dev)

    # Without the 2D request, the cast value is returned with whatever shape it has.
    if not tensor:
        return x

    ndim = len(x.shape)
    if ndim > 2:
        raise NameError(f"Provided data has too much dimensions ({ndim}).")

    # Append trailing singleton axes until the tensor is 2D: () -> (1,) -> (1, 1) and (n,) -> (n, 1).
    for _ in range(2 - ndim):
        x = x.unsqueeze(-1)
    return x
def casti(x, dev=None, tensor=False) -> torch.Tensor | None:
    r"""
    Casts the input to a PyTorch integer tensor of type :attr:`kerch.ITYPE`.

    Non-tensor inputs (scalars, sequences, NumPy arrays, ...) are converted with :func:`torch.tensor`
    (without gradient tracking). Existing tensors are converted with :meth:`torch.Tensor.type`. `None`
    values are not cast and are returned as is.

    :param x: The input to cast.
    :param dev: The device to cast the tensor to. Defaults to `None`, in which case existing tensors stay on
        their current device and newly created tensors use PyTorch's default device.
    :param tensor: If True, the result is a 2D tensor. A scalar becomes shape (1, 1), a 1D input of length n
        becomes (n, 1), and a 2D input keeps its shape. Inputs with more than 2 dimensions raise an error.
        If False (default), every singleton dimension is squeezed out. For example, a column (n, 1) becomes
        (n,) and a single element becomes a 0-dimensional tensor.
    :return: The input cast to a PyTorch integer tensor, with optional device.
    :raises NameError: If `tensor` is True and the data has more than 2 dimensions.
    :type x: int | torch.Tensor | np.ndarray | None
    :type dev: Optional[torch.device]
    :type tensor: bool
    :rtype: torch.Tensor | None
    """
    if x is None:
        return None

    if not torch.is_tensor(x):
        x = torch.tensor(x, requires_grad=False, dtype=ITYPE, device=dev)
    else:
        x = x.type(ITYPE)
        if dev is not None:
            x = x.to(dev)

    if not tensor:
        # Default path: drop all singleton dimensions (kept for backward compatibility).
        return x.squeeze()

    dim = x.dim()
    if dim > 2:
        raise NameError(f"Provided data has too much dimensions ({dim}).")
    if dim == 0:
        x = x.unsqueeze(0)
    if x.dim() == 1:
        x = x.unsqueeze(1)
    # No squeeze here: squeezing would undo the 2D promotion requested by `tensor=True`.
    return x
def check_representation(representation: str = None, default: str = None, cls=None) -> str:
    r"""
    Validates a representation name, falling back to a default when none is given.

    When `representation` is `None`, `default` is used in its place. The resulting value must be one of
    `"primal"` or `"dual"`; anything else (including `None` when no default is provided) raises
    :class:`RepresentationError`.

    :param representation: The representation to check.
    :param default: Default representation for the case where `representation` is `None`.
    :param cls: An instance of :class:`kerch.feature.Logger` to throw the error from, typically the one calling
        this method. This is optional.
    :return: "primal" | "dual"
    :raises RepresentationError: If the (possibly defaulted) representation is not valid.
    :type representation: str, optional
    :type default: str, optional
    :type cls: kerch.feature.Logger, optional
    :rtype: str
    """
    chosen = default if representation is None else representation

    if chosen in ("primal", "dual"):
        return chosen
    raise RepresentationError(cls)
def capitalize_only_first(val: str) -> str:
    r"""
    Returns the input string with its first character upper-cased and the rest left unchanged.

    Unlike :meth:`str.capitalize`, the remaining characters are not lower-cased (e.g. ``"kPCA"`` becomes
    ``"KPCA"``). An empty string is returned unchanged.

    :param val: String to be capitalized.
    :return: Capitalized string.
    :type val: str
    :rtype: str
    """
    # Slicing instead of indexing turns the empty string into a no-op rather than an IndexError.
    return val[:1].upper() + val[1:]