import torch
from colibri.optics.functional import psf_single_doe_spectral, convolutional_sensing, wiener_filter, ideal_panchromatic_sensor
from colibri.optics.sota_does import conventional_lens, nbk7_refractive_index
from .utils import BaseOpticsLayer
class SingleDOESpectral(BaseOpticsLayer):
r"""
Single Diffractive Optical Element for Spectral Imaging
This optical system allows the capture of spatio-spectral information through a wavelength-dependent phase coding of light by a diffractive optical element.
Mathematically, this system can be described as follows:
.. math::
\mathbf{y} = \forwardLinear_{\learnedOptics}(\mathbf{x}) + \noise
where :math:`\noise` is the sensor noise, :math:`\mathbf{x}\in\xset` is the input optical field, and :math:`\mathbf{y}\in\yset` is the acquired signal. For SingleDOESpectral, :math:`\xset = \mathbb{R}^{L \times M \times N}` and :math:`\yset = \mathbb{R}^{M \times N}`, and :math:`\forwardLinear_{\learnedOptics}:\xset\rightarrow \yset` is the forward operator of the diffractive optical element, such that
.. math::
\begin{align*}
\forwardLinear_{\learnedOptics}: \mathbf{x} &\mapsto \mathbf{y} \\
\mathbf{y} &= \sum_{l=1}^{L} \psf(\learnedOptics)_l * \mathbf{x}_{:, l}
\end{align*}
where :math:`\psf(\learnedOptics)_l` is the optical point spread function (PSF) of the diffractive optical element at wavelength :math:`\lambda_l`, computed with :func:`colibri.optics.functional.psf_single_doe_spectral`.
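Example:
    A minimal usage sketch (the input shape and tensor values here are illustrative assumptions):

    .. code-block:: python

        import torch

        acquisition_model = SingleDOESpectral(input_shape=(31, 256, 256))

        x = torch.rand(1, 31, 256, 256)                          # spectral cube (B, L, M, N)
        y = acquisition_model(x, type_calculation="forward")     # measurement (B, 1, M, N)
        x0 = acquisition_model(y, type_calculation="backward")   # Wiener-filtered estimate (B, L, M, N)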
"""
def __init__(self, input_shape,
height_map = None,
aperture = None,
pixel_size = 1e-6,
wavelengths = torch.tensor([450, 550, 650])*1e-9,
source_distance = 3,
sensor_distance = 50e-3,
sensor_spectral_sensitivity=ideal_panchromatic_sensor,
doe_refractive_index = None,
approximation = "fresnel",
domain = "fourier",
trainable = False):
r"""
Initializes the SingleDOESpectral layer.
Args:
input_shape (tuple): The shape of the input spectral image in the format (L, M, N), where L is the number of spectral bands, M is the height, and N is the width.
height_map (torch.Tensor, optional): The height map of the DOE (Diffractive Optical Element). If None, the height map of a conventional lens (50e-3 m focal length) is generated.
aperture (torch.Tensor, optional): The aperture of the DOE. If None, a circular aperture is generated.
wavelengths (torch.Tensor): The wavelengths corresponding to the spectral bands. Defaults to torch.tensor([450, 550, 650])*1e-9.
source_distance (float): The distance between the source and the DOE, in meters. Defaults to 3.
sensor_distance (float): The distance between the DOE and the sensor, in meters. Defaults to 50e-3.
pixel_size (float): The size of a sensor pixel, in meters. Defaults to 1e-6.
sensor_spectral_sensitivity (callable, optional): The spectral sensitivity model of the sensor, applied to the sensed field. Defaults to ideal_panchromatic_sensor.
doe_refractive_index (callable or float, optional): The refractive index of the DOE. If None, the N-BK7 refractive index model (nbk7_refractive_index) is used.
approximation (str, optional): The approximation used to calculate the PSF. It can be "fresnel", "angular_spectrum" or "fraunhofer". Defaults to "fresnel".
domain (str, optional): The domain used to calculate the PSF. It can be "fourier" or "spatial". Defaults to "fourier".
trainable (bool, optional): Whether the height map of the DOE is trainable or not. Defaults to False.
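Example:
    A configuration sketch with explicit optical parameters (the values below are illustrative assumptions, not recommended settings):

    .. code-block:: python

        import torch

        layer = SingleDOESpectral(
            input_shape=(31, 256, 256),
            pixel_size=4e-6,
            wavelengths=torch.linspace(450e-9, 650e-9, 31),
            source_distance=1.0,
            sensor_distance=25e-3,
            approximation="angular_spectrum",
            trainable=True,
        )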
"""
self.trainable = trainable
self.aperture = aperture
self.sensor_spectral_sensitivity = sensor_spectral_sensitivity
self.wavelengths = wavelengths
self.source_distance = source_distance
self.sensor_distance = sensor_distance
self.pixel_size = pixel_size
self.approximation = approximation
self.domain = domain
self.refractive_index = nbk7_refractive_index if doe_refractive_index is None else doe_refractive_index
self.L, self.M, self.N = input_shape # Extract spectral image shape
# Fall back to a conventional lens profile and circular aperture when none are provided
if height_map is None:
    height_map, _ = conventional_lens(M=self.M, N=self.N, focal=50e-3, radius=1)
if aperture is None:
    _, aperture = conventional_lens(M=self.M, N=self.N, focal=50e-3, radius=1)
self.aperture = aperture  # keep the stored aperture in sync with the generated default
# Register the DOE height map as a PyTorch parameter (trainable if requested)
height_map = torch.nn.Parameter(height_map, requires_grad=self.trainable)
super(SingleDOESpectral, self).__init__(learnable_optics=height_map, sensing=self.convolution, backward=self.deconvolution)
def get_psf(self, height_map=None):
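    r"""
    Computes the spectral PSF stack of the DOE from ``height_map``, or from the layer's
    learnable height map when ``height_map`` is None.
    """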
if height_map is None:
height_map = self.learnable_optics
return psf_single_doe_spectral(height_map=height_map,
aperture=self.aperture,
wavelengths=self.wavelengths,
source_distance=self.source_distance,
sensor_distance=self.sensor_distance,
pixel_size=self.pixel_size,
refractive_index=self.refractive_index,
approximation=self.approximation)
def convolution(self, x, height_map):
r"""
Forward operator of the SingleDOESpectral layer.
Args:
x (torch.Tensor): Input spectral image with shape (B, L, M, N)
height_map (torch.Tensor): Height map of the DOE used to compute the PSF
Returns:
torch.Tensor: Measurement tensor with shape (B, 1, M, N)
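Example:
    This method is called internally by :meth:`forward`; an equivalent direct call (sketch, with ``layer`` and ``x`` as in the class-level example) is:

    .. code-block:: python

        y = layer.convolution(x, layer.learnable_optics)  # same result as layer(x, "forward")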
"""
psf = self.get_psf(height_map)
field = convolutional_sensing(x, psf, domain=self.domain)
return self.sensor_spectral_sensitivity(field)
def deconvolution(self, x, height_map, alpha=1e-3):
r"""
Backward operator of the SingleDOESpectral layer.
Args:
x (torch.Tensor): Measurement tensor with shape (B, 1, M, N)
height_map (torch.Tensor): Height map of the DOE used to compute the PSF
alpha (float): Regularization parameter of the Wiener filter. Defaults to 1e-3.
Returns:
torch.Tensor: Reconstructed spectral image with shape (B, L, M, N)
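Note:
    The backward operator applies a Wiener deconvolution with the spectral PSF. Assuming the standard frequency-domain formulation, the estimate for band :math:`l` is

    .. math::
        \hat{\mathbf{x}}_{:, l} = \mathcal{F}^{-1}\left( \frac{\overline{\mathcal{F}(\psf_l)}}{|\mathcal{F}(\psf_l)|^2 + \alpha} \, \mathcal{F}(\mathbf{y}) \right)

    where :math:`\psf_l` is the PSF at wavelength :math:`\lambda_l`, :math:`\mathcal{F}` is the 2D Fourier transform, and :math:`\alpha` is the regularization parameter passed to :func:`colibri.optics.functional.wiener_filter`.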
"""
psf = self.get_psf(height_map)
x = self.sensor_spectral_sensitivity(x)
return wiener_filter(x, psf, alpha)
def forward(self, x, type_calculation="forward"):
r"""
Performs the forward or backward operator according to the type_calculation
Args:
x (torch.Tensor): Input tensor with shape (B, L, M, N) for "forward" and "forward_backward", or (B, 1, M, N) for "backward"
type_calculation (str): Either "forward", "backward" or "forward_backward". Defaults to "forward".
Returns:
torch.Tensor: Output tensor with shape (B, 1, M, N) for "forward", or (B, L, M, N) otherwise
Raises:
ValueError: If type_calculation is not "forward", "backward" or "forward_backward"
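Example:
    A minimal sketch of the three modes (``layer``, ``x`` and the shapes are illustrative):

    .. code-block:: python

        y = layer(x, type_calculation="forward")               # (B, 1, M, N) measurement
        x_init = layer(y, type_calculation="backward")         # (B, L, M, N) Wiener estimate
        x_fb = layer(x, type_calculation="forward_backward")   # forward, then backward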
"""
return super(SingleDOESpectral, self).forward(x, type_calculation)