Source code for colibri.optics.modulo

import torch

from colibri.optics.functional import modulo

from .utils import BaseOpticsLayer




[docs] class Modulo(BaseOpticsLayer): r""" Modulo Sensing Operator This operator applies the non-linear modulo operation to the input tensor. Mathematically, this operator can be described as follows: .. math:: \mathbf{y} = \mathcal{M}_t(\mathbf{x}) = \text{mod}(\mathbf{x}, t) where :math:`\mathbf{x}\in\xset` is the input optical field, :math:`\mathbf{y}\in\yset` is the acquired signal, and :math:`t` is the threshold value. """ def __init__(self, threshold=1.0): self.threshold = torch.tensor([threshold], requires_grad=False) super(Modulo, self).__init__(learnable_optics=threshold, sensing=modulo, backward=None)
[docs] def forward(self, x, type_calculation="forward"): r""" Forward pass of the modulo operator Args: x (torch.Tensor): Input tensor with shape (B, L, M, N) type_calculation (str): Type of calculation to perform. Default: "forward" Returns: torch.Tensor: Output tensor with shape (B, L, M, N) """ if type_calculation == "backward": raise ValueError("The modulo operator does not have a backward pass") return super(Modulo, self).forward(x, type_calculation)