Source code for colibri.optics.doe

import torch
from colibri.optics.functional import psf_single_doe_spectral, convolutional_sensing, wiener_filter, ideal_panchromatic_sensor
from colibri.optics.sota_does import conventional_lens, nbk7_refractive_index
from .utils import BaseOpticsLayer

class SingleDOESpectral(BaseOpticsLayer):
    r"""
    Single Diffractive Optical Element for Spectral Imaging

    This optical system allows for the capture of spatio-spectral information through a
    wavelength-dependent phase coding of light through a diffractive optical element.

    Mathematically, this system can be described as follows:

    .. math::

        \mathbf{y} = \forwardLinear_{\learnedOptics}(\mathbf{x}) + \noise

    where :math:`\noise` is the sensor noise, :math:`\mathbf{x}\in\xset` is the input optical field,
    :math:`\mathbf{y}\in\yset` is the acquired signal, for SingleDOESpectral,
    :math:`\xset = \mathbb{R}^{L \times M \times N}` and :math:`\yset = \mathbb{R}^{M \times N}`,
    and :math:`\forwardLinear_{\learnedOptics}:\xset\rightarrow \yset` is the forward operator of
    the diffractive optical element, such as

    .. math::
        \begin{align*}
        \forwardLinear_{\learnedOptics}: \mathbf{x} &\mapsto \mathbf{y} \\
        \mathbf{y} &= \sum_{l=1}^{L} \psf(\learnedOptics)_l * \mathbf{x}_{l}
        \end{align*}

    with :math:`\mathbf{x}_{l}` the :math:`l`-th spectral band of the input and
    :math:`\psf(\learnedOptics)_l` the optical point spread function (PSF) of the diffractive
    optical element at wavelength :math:`\lambda_l`, calculated using
    :func:`colibri.optics.functional.psf_single_doe_spectral`.
    """

    def __init__(self, input_shape, height_map=None, aperture=None, pixel_size=1e-6,
                 wavelengths=None, source_distance=3, sensor_distance=50e-3,
                 sensor_spectral_sensitivity=ideal_panchromatic_sensor,
                 doe_refractive_index=None, approximation="fresnel", domain="fourier",
                 trainable=False):
        r"""
        Initializes the SingleDOESpectral layer.

        Args:
            input_shape (tuple): The shape of the input spectral image in the format (L, M, N),
                where L is the number of spectral bands, M is the height, and N is the width.
            height_map (torch.Tensor, optional): The height map of the DOE (Diffractive Optical
                Element). If None, the height map returned by
                ``conventional_lens(M=M, N=N, focal=50e-3, radius=1)`` is used.
            aperture (torch.Tensor, optional): The aperture of the DOE. If None, the aperture
                returned by ``conventional_lens(M=M, N=N, focal=50e-3, radius=1)`` is used
                (documented as a circular aperture).
            pixel_size (float): The size of a pixel in the sensor, by default 1e-6.
            wavelengths (torch.Tensor, optional): The wavelengths corresponding to the spectral
                bands. If None, ``torch.tensor([450, 550, 650]) * 1e-9`` is used.
            source_distance (float): The distance between the source and the DOE, by default 3 meters.
            sensor_distance (float): The distance between the DOE and the sensor, by default 50e-3 meters.
            sensor_spectral_sensitivity (callable, optional): Function applied to the sensed field
                to model the sensor's spectral response. Defaults to ideal_panchromatic_sensor.
            doe_refractive_index (optional): The refractive index of the DOE. If None,
                nbk7_refractive_index is used.
            approximation (str, optional): The approximation used to calculate the PSF. It can be
                "fresnel", "angular_spectrum" or "fraunhofer". Defaults to "fresnel".
            domain (str, optional): The domain used for the convolutional sensing. It can be
                "fourier" or "spatial". Defaults to "fourier".
            trainable (bool, optional): Whether the height map of the DOE is trainable or not.
                Defaults to False.
        """
        if wavelengths is None:
            # Built per instance: a tensor default in the signature would be created once at
            # import time and shared by every layer.
            wavelengths = torch.tensor([450, 550, 650]) * 1e-9

        self.trainable = trainable
        self.sensor_spectral_sensitivity = sensor_spectral_sensitivity
        self.wavelengths = wavelengths
        self.source_distance = source_distance
        self.sensor_distance = sensor_distance
        self.pixel_size = pixel_size
        self.approximation = approximation
        self.domain = domain
        self.refractive_index = (
            nbk7_refractive_index if doe_refractive_index is None else doe_refractive_index
        )
        self.L, self.M, self.N = input_shape  # Extract spectral image shape

        # Generate the default lens only once, and only when some part of it is needed.
        if height_map is None or aperture is None:
            default_height_map, default_aperture = conventional_lens(
                M=self.M, N=self.N, focal=50e-3, radius=1
            )
            if height_map is None:
                height_map = default_height_map
            if aperture is None:
                aperture = default_aperture

        # Store the aperture only after the default has been resolved; storing it earlier
        # would keep None and silently drop the generated aperture.
        self.aperture = aperture

        # Register the height map as the layer's learnable optics (frozen unless trainable).
        height_map = torch.nn.Parameter(height_map, requires_grad=self.trainable)
        super().__init__(learnable_optics=height_map, sensing=self.convolution,
                         backward=self.deconvolution)

    def get_psf(self, height_map=None):
        r"""
        Computes the PSF of the DOE for every wavelength in ``self.wavelengths``.

        Args:
            height_map (torch.Tensor, optional): Height map to use. If None,
                ``self.learnable_optics`` (presumably set by BaseOpticsLayer) is used.

        Returns:
            torch.Tensor: The output of :func:`psf_single_doe_spectral` for this layer's settings.
        """
        if height_map is None:
            height_map = self.learnable_optics
        return psf_single_doe_spectral(height_map=height_map,
                                       aperture=self.aperture,
                                       wavelengths=self.wavelengths,
                                       source_distance=self.source_distance,
                                       sensor_distance=self.sensor_distance,
                                       pixel_size=self.pixel_size,
                                       refractive_index=self.refractive_index,
                                       approximation=self.approximation)

    def convolution(self, x, height_map):
        r"""
        Forward operator of the SingleDOESpectral layer.

        Args:
            x (torch.Tensor): Input tensor with shape (B, L, M, N)
            height_map (torch.Tensor): Height map used to compute the PSF.

        Returns:
            torch.Tensor: Output tensor with shape (B, 1, M, N)
        """
        psf = self.get_psf(height_map)
        field = convolutional_sensing(x, psf, domain=self.domain)
        return self.sensor_spectral_sensitivity(field)

    def deconvolution(self, x, height_map, alpha=1e-3):
        r"""
        Backward operator of the SingleDOESpectral layer (Wiener deconvolution).

        Args:
            x (torch.Tensor): Input tensor with shape (B, 1, M, N)
            height_map (torch.Tensor): Height map used to compute the PSF.
            alpha (float, optional): Regularization constant passed to
                :func:`wiener_filter`. Defaults to 1e-3.

        Returns:
            torch.Tensor: Output tensor with shape (B, L, M, N)
        """
        psf = self.get_psf(height_map)
        # NOTE(review): the sensor response is applied to the measurement before filtering;
        # confirm this is intended for a single-channel input.
        x = self.sensor_spectral_sensitivity(x)
        return wiener_filter(x, psf, alpha)

    def forward(self, x, type_calculation="forward"):
        r"""
        Performs the forward or backward operator according to the type_calculation.

        Args:
            x (torch.Tensor): Input tensor; (B, L, M, N) for "forward" and
                "forward_backward", (B, 1, M, N) for "backward".
            type_calculation (str): String, it can be "forward", "backward" or "forward_backward"

        Returns:
            torch.Tensor: Output tensor; presumably (B, 1, M, N) for "forward" and
            (B, L, M, N) for "backward" and "forward_backward" (see ``convolution`` /
            ``deconvolution``).

        Raises:
            ValueError: If type_calculation is not "forward", "backward" or "forward_backward"
        """
        return super().forward(x, type_calculation)