Source code for colibri.optics.functional

import torch
import numpy as np

def prism_operator(x: torch.Tensor, shift_sign: int = 1) -> torch.Tensor:
    r"""
    Prism operator: linearly shifts the input tensor x along the spectral dimension.

    Args:
        x (torch.Tensor): Input tensor with shape (B, L, M, N)
        shift_sign (int): Direction of the shift, either 1 or -1; 1 shifts to the right, -1 shifts to the left.

    Returns:
        torch.Tensor: Output tensor with shape (B, L, M, N + L - 1) if shift_sign is 1, or (B, L, M, N - L + 1) if shift_sign is -1
    """
    assert shift_sign == 1 or shift_sign == -1, "The shift sign must be 1 or -1"
    _, L, M, N = x.shape  # Extract spectral image shape
    x = torch.unbind(x, dim=1)
    if shift_sign == 1:
        # Shifting produced by the prism
        x = [torch.nn.functional.pad(x[l], (l, L - l - 1)) for l in range(L)]
    else:
        # Unshifting produced by the prism
        x = [x[l][:, :, l:N - (L - 1) + l] for l in range(L)]
    x = torch.stack(x, dim=1)
    return x
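# Usage sketch (illustrative, not part of the module): shifting with the prism
# operator and unshifting again recovers the original cube exactly, since the
# unshift slices out precisely the region padded by the shift.
def _example_prism_operator():
    x = torch.rand(2, 8, 32, 32)                # (B, L, M, N)
    y = prism_operator(x, shift_sign=1)         # (2, 8, 32, 39)
    x_rec = prism_operator(y, shift_sign=-1)    # back to (2, 8, 32, 32)
    assert torch.equal(x, x_rec)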
def forward_color_cassi(x: torch.Tensor, ca: torch.Tensor) -> torch.Tensor:
    r"""
    Forward operator of the color coded aperture snapshot spectral imager (Color-CASSI).

    For more information refer to: Colored Coded Aperture Design by Concentration of Measure in Compressive Spectral Imaging https://doi.org/10.1109/TIP.2014.2310125

    Args:
        x (torch.Tensor): Spectral image with shape (B, L, M, N)
        ca (torch.Tensor): Coded aperture with shape (1, L, M, N)

    Returns:
        torch.Tensor: Measurement with shape (B, 1, M, N + L - 1)
    """
    y = torch.multiply(x, ca)
    y = prism_operator(y, shift_sign=1)
    return y.sum(dim=1, keepdim=True)
def backward_color_cassi(y: torch.Tensor, ca: torch.Tensor) -> torch.Tensor:
    r"""
    Backward operator of the color coded aperture snapshot spectral imager (Color-CASSI).

    For more information refer to: Colored Coded Aperture Design by Concentration of Measure in Compressive Spectral Imaging https://doi.org/10.1109/TIP.2014.2310125

    Args:
        y (torch.Tensor): Measurement with shape (B, 1, M, N + L - 1)
        ca (torch.Tensor): Coded aperture with shape (1, L, M, N)

    Returns:
        torch.Tensor: Spectral image with shape (B, L, M, N)
    """
    y = torch.tile(y, [1, ca.shape[1], 1, 1])
    y = prism_operator(y, shift_sign=-1)
    x = torch.multiply(y, ca)
    return x
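# Usage sketch (illustrative, not part of the module): a forward/backward
# Color-CASSI pair on a random cube with a random binary coded aperture. The
# backward operator is the adjoint of the forward operator, not its inverse.
def _example_color_cassi():
    x = torch.rand(1, 8, 32, 32)                   # (B, L, M, N)
    ca = (torch.rand(1, 8, 32, 32) > 0.5).float()  # binary coded aperture
    y = forward_color_cassi(x, ca)                 # (1, 1, 32, 39)
    x_adj = backward_color_cassi(y, ca)            # (1, 8, 32, 32)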
def forward_dd_cassi(x: torch.Tensor, ca: torch.Tensor) -> torch.Tensor:
    r"""
    Forward operator of the dual disperser coded aperture snapshot spectral imager (DD-CASSI).

    For more information refer to: Single-shot compressive spectral imaging with a dual-disperser architecture https://doi.org/10.1364/OE.15.014013

    Args:
        x (torch.Tensor): Spectral image with shape (B, L, M, N)
        ca (torch.Tensor): Coded aperture with shape (1, 1, M, N + L - 1)

    Returns:
        torch.Tensor: Measurement with shape (B, 1, M, N)
    """
    _, L, M, N = x.shape  # Extract spectral image shape
    assert ca.shape[-1] == N + L - 1, "The coded aperture must have the same size as a dispersed scene"
    ca = torch.tile(ca, [1, L, 1, 1])
    ca = prism_operator(ca, shift_sign=-1)
    y = torch.multiply(x, ca)
    return y.sum(dim=1, keepdim=True)
def backward_dd_cassi(y: torch.Tensor, ca: torch.Tensor) -> torch.Tensor:
    r"""
    Backward operator of the dual disperser coded aperture snapshot spectral imager (DD-CASSI).

    For more information refer to: Single-shot compressive spectral imaging with a dual-disperser architecture https://doi.org/10.1364/OE.15.014013

    Args:
        y (torch.Tensor): Measurement with shape (B, 1, M, N)
        ca (torch.Tensor): Coded aperture with shape (1, 1, M, N + L - 1)

    Returns:
        torch.Tensor: Spectral image with shape (B, L, M, N)
    """
    _, _, M, N_hat = ca.shape  # Extract coded aperture shape
    L = N_hat - M + 1  # Number of shifts (assumes a square spatial grid, M == N)
    y = torch.tile(y, [1, L, 1, 1])
    ca = torch.tile(ca, [1, L, 1, 1])
    ca = prism_operator(ca, shift_sign=-1)
    return torch.multiply(y, ca)
def forward_sd_cassi(x: torch.Tensor, ca: torch.Tensor) -> torch.Tensor:
    r"""
    Forward operator of the single disperser coded aperture snapshot spectral imager (SD-CASSI).

    For more information refer to: Compressive Coded Aperture Spectral Imaging: An Introduction https://doi.org/10.1109/MSP.2013.2278763

    Args:
        x (torch.Tensor): Spectral image with shape (B, L, M, N)
        ca (torch.Tensor): Coded aperture with shape (1, 1, M, N)

    Returns:
        torch.Tensor: Measurement with shape (B, 1, M, N + L - 1)
    """
    y1 = torch.multiply(x, ca)  # Multiplication of the scene by the coded aperture
    # Shift and sum
    y2 = prism_operator(y1, shift_sign=1)
    return y2.sum(dim=1, keepdim=True)
def backward_sd_cassi(y: torch.Tensor, ca: torch.Tensor) -> torch.Tensor:
    r"""
    Backward operator of the single disperser coded aperture snapshot spectral imager (SD-CASSI).

    For more information refer to: Compressive Coded Aperture Spectral Imaging: An Introduction https://doi.org/10.1109/MSP.2013.2278763

    Args:
        y (torch.Tensor): Measurement with shape (B, 1, M, N + L - 1)
        ca (torch.Tensor): Coded aperture with shape (1, 1, M, N)

    Returns:
        torch.Tensor: Spectral image with shape (B, L, M, N)
    """
    _, _, M, N = y.shape  # Extract measurement shape
    L = N - M + 1  # Number of shifts (assumes a square spatial grid)
    y = torch.tile(y, [1, L, 1, 1])
    y = prism_operator(y, shift_sign=-1)
    return torch.multiply(y, ca)
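# Usage sketch (illustrative, not part of the module): SD-CASSI forward and
# backward pass on a square scene; the number of spectral bands is recovered
# from the sheared measurement width inside backward_sd_cassi.
def _example_sd_cassi():
    x = torch.rand(1, 8, 32, 32)                   # (B, L, M, N), with M == N
    ca = (torch.rand(1, 1, 32, 32) > 0.5).float()  # binary coded aperture
    y = forward_sd_cassi(x, ca)                    # (1, 1, 32, 39)
    x_adj = backward_sd_cassi(y, ca)               # (1, 8, 32, 32)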
def forward_spc(x: torch.Tensor, H: torch.Tensor) -> torch.Tensor:
    r"""
    Forward propagation through the Single Pixel Camera (SPC) model.

    For more information refer to: Optimized Sensing Matrix for Single Pixel Multi-Resolution Compressive Spectral Imaging https://doi.org/10.1109/TIP.2020.2971150

    Args:
        x (torch.Tensor): Input image tensor of size (B, L, M, N).
        H (torch.Tensor): Measurement matrix of size (S, M*N).

    Returns:
        torch.Tensor: Output measurement tensor of size (B, S, L).
    """
    B, L, M, N = x.size()
    x = x.contiguous().view(B, L, M * N)
    x = x.permute(0, 2, 1)

    # Measurement
    H = H.unsqueeze(0).repeat(B, 1, 1)
    y = torch.bmm(H, x)
    return y
def backward_spc(y: torch.Tensor, H: torch.Tensor, pinv=False) -> torch.Tensor:
    r"""
    Inverse operation to reconstruct the image from measurements.

    For more information refer to: Optimized Sensing Matrix for Single Pixel Multi-Resolution Compressive Spectral Imaging https://doi.org/10.1109/TIP.2020.2971150

    Args:
        y (torch.Tensor): Measurement tensor of size (B, S, L).
        H (torch.Tensor): Measurement matrix of size (S, M*N).
        pinv (bool): If True the pseudo-inverse of H is used, otherwise the transpose of H is used. Defaults to False.

    Returns:
        torch.Tensor: Reconstructed image tensor of size (B, L, M, N).
    """
    Hinv = torch.pinverse(H) if pinv else torch.transpose(H, 0, 1)
    Hinv = Hinv.unsqueeze(0).repeat(y.shape[0], 1, 1)

    x = torch.bmm(Hinv, y)
    x = x.permute(0, 2, 1)
    b, c, hw = x.size()
    h = int(np.sqrt(hw))  # Assumes a square image, M == N
    x = x.reshape(b, c, h, h)
    return x
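# Usage sketch (illustrative, not part of the module): compress a cube with a
# random Gaussian sensing matrix and reconstruct it with the pseudo-inverse.
# The number of measurements S and the 16x16 grid are arbitrary demo choices.
def _example_spc():
    B, L, M, N, S = 2, 4, 16, 16, 64
    x = torch.rand(B, L, M, N)
    H = torch.randn(S, M * N)               # random sensing matrix
    y = forward_spc(x, H)                   # (B, S, L)
    x_hat = backward_spc(y, H, pinv=True)   # (B, L, M, N); square image assumed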
### Wave optics
def get_spatial_coords(M: int, N: int, pixel_size: float, device=torch.device('cpu'), type='cartesian') -> tuple:
    r"""
    Generate the spatial coordinates for wave optics simulations in a specific coordinate system.

    .. note::
        * if type is 'cartesian', we generate :math:`(x, y)` coordinates in the Cartesian coordinate system, where

            * :math:`x \in \bigg[-\frac{\Delta\cdot N}{2}, \frac{\Delta\cdot N}{2} \bigg]`
            * :math:`y \in \bigg[-\frac{\Delta\cdot M}{2}, \frac{\Delta\cdot M}{2} \bigg]`

        * if type is 'polar', we generate :math:`(r, \theta)` coordinates in the polar coordinate system, where

            * :math:`r \in \Bigg[0, \sqrt{\bigg(\frac{\Delta\cdot N}{2}\bigg)^2 + \bigg(\frac{\Delta\cdot M}{2}\bigg)^2} \Bigg]`
            * :math:`\theta \in [-\pi, \pi]`

        with :math:`\Delta` being the pixel size, :math:`M` the number of pixels in the y axis, and :math:`N` the number of pixels in the x axis.

    Args:
        M (int): Number of pixels in the y axis.
        N (int): Number of pixels in the x axis.
        pixel_size (float): Pixel size in meters.
        device (torch.device): Device, for more see torch.device().
        type (str): Type of coordinate system to generate ('cartesian' or 'polar').

    Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple of tensors representing the X and Y coordinates if 'cartesian', or radius (r) and angle (theta) if 'polar'.
    """
    x = torch.linspace(-pixel_size * N / 2, pixel_size * N / 2, N).to(device=device)
    y = torch.linspace(-pixel_size * M / 2, pixel_size * M / 2, M).to(device=device)
    # With indexing='ij' the first output varies along rows (the y axis) and the
    # second along columns (the x axis), so the grids are unpacked in (y, x) order.
    y, x = torch.meshgrid(y, x, indexing='ij')
    if type == "polar":
        r = torch.sqrt(x ** 2 + y ** 2)
        theta = torch.atan2(y, x)
        return r, theta
    else:
        return x, y
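# Usage sketch (illustrative, not part of the module): a 64x64 grid with a
# 1 um pitch in both coordinate systems; the pitch value is arbitrary.
def _example_spatial_coords():
    X, Y = get_spatial_coords(64, 64, 1e-6)                    # Cartesian grids, (64, 64) each
    r, theta = get_spatial_coords(64, 64, 1e-6, type='polar')  # radius and angle grids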
def wave_number(wavelength: torch.Tensor):
    r"""
    Wavenumber

    .. math::
        k = \frac{2 \pi}{\lambda}

    where :math:`\lambda` is the wavelength in meters.

    Args:
        wavelength (torch.Tensor): Wavelength in meters.

    Returns:
        torch.Tensor: Wavenumber.
    """
    return 2 * torch.pi / wavelength
def transfer_function_fresnel(M: int, N: int, pixel_size: float, wavelengths: torch.Tensor, distance: float, device=torch.device('cpu')) -> torch.Tensor:
    r"""
    The transfer function for the Fresnel propagation can be written as follows:

    .. math::
        H(f_x, f_y, \lambda) = e^{j k s \left(1 - \frac{\lambda^2}{2} (f_x^2 + f_y^2)\right)}

    where :math:`f_x` and :math:`f_y` are the spatial frequencies, :math:`\lambda` is the wavelength, :math:`s` is the distance of propagation and :math:`k` is the wavenumber.

    Args:
        M (int): Resolution at Y axis in pixels.
        N (int): Resolution at X axis in pixels.
        pixel_size (float): Pixel size in meters.
        wavelengths (torch.Tensor): Wavelengths in meters.
        distance (float): Distance in meters.
        device (torch.device): Device, for more see torch.device().

    Returns:
        torch.Tensor: Complex kernel in Fourier domain with shape (len(wavelengths), M, N).
    """
    # Radial frequency grid; the sampling step 1/(N * pixel_size) is used for
    # both axes, which assumes a square grid (M == N)
    fr, _ = get_spatial_coords(M, N, 1 / (N * pixel_size), device=device, type='polar')
    fr = fr.unsqueeze(0)
    H = torch.exp(1j * wave_number(wavelengths) * distance * (1 - (fr ** 2) * (wavelengths ** 2) / 2))
    return H
def transfer_function_angular_spectrum(M: int, N: int, pixel_size: float, wavelengths: torch.Tensor, distance: float, device: torch.device = torch.device('cpu')) -> torch.Tensor:
    r"""
    The transfer function for the angular spectrum propagation can be written as follows:

    .. math::
        H(f_x, f_y, \lambda) = e^{\frac{j s 2 \pi}{\lambda} \sqrt{1 - \lambda^2 (f_x^2 + f_y^2)}}

    where :math:`f_x` and :math:`f_y` are the spatial frequencies, :math:`\lambda` is the wavelength and :math:`s` is the distance of propagation.

    Args:
        M (int): Resolution at Y axis in pixels.
        N (int): Resolution at X axis in pixels.
        pixel_size (float): Pixel size in meters.
        wavelengths (torch.Tensor): Wavelengths in meters.
        distance (float): Distance in meters.
        device (torch.device): Device, for more see torch.device().

    Returns:
        torch.Tensor: Complex kernel in Fourier domain with shape (len(wavelengths), M, N).
    """
    fr, _ = get_spatial_coords(M, N, 1 / (N * pixel_size), device=device, type='polar')
    fr = fr.unsqueeze(0)
    H = torch.exp(1j * wave_number(wavelengths) * distance * (1 - (fr ** 2) * wavelengths ** 2) ** 0.5)
    return H
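# Usage sketch (illustrative, not part of the module): build both transfer
# functions for three visible wavelengths. The wavelengths tensor is reshaped
# to (L, 1, 1) so it broadcasts against the (1, M, N) frequency grid; the
# pixel size and distance are arbitrary demo values.
def _example_transfer_functions():
    wl = torch.tensor([450e-9, 550e-9, 650e-9]).view(-1, 1, 1)
    H_fresnel = transfer_function_fresnel(64, 64, 5e-6, wl, 1e-3)           # (3, 64, 64), complex
    H_angular = transfer_function_angular_spectrum(64, 64, 5e-6, wl, 1e-3)  # (3, 64, 64), complex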
def fraunhofer_propagation(field: torch.Tensor, M: int, N: int, pixel_size: float, wavelengths: torch.Tensor, distance: float, device: torch.device = torch.device('cpu')) -> torch.Tensor:
    r"""
    The Fraunhofer approximation of :math:`U_0(x',y')` is its Fourier transform, :math:`\mathcal{F}\{U_0\}`, with an additional phase factor that depends on the distance of propagation, :math:`z`:

    .. math::
        U(x,y,z) \approx \frac{e^{jkz}e^{\frac{jk\left(x^2+y^2\right)}{2z}}}{j\lambda z} \mathcal{F}\left\{U_0(x,y)\right\}\left(\frac{x}{\lambda z}, \frac{y}{\lambda z}\right)

    where :math:`U(x,y,z)` is the field at distance :math:`z` from the source, :math:`U_0(x,y)` is the field at the source, :math:`\mathcal{F}` is the Fourier transform operator, :math:`k` is the wavenumber, :math:`\lambda` is the wavelength, and :math:`\frac{x}{\lambda z}` and :math:`\frac{y}{\lambda z}` are the spatial frequencies.

    Args:
        field (torch.Tensor): Input field.
        M (int): Resolution at Y axis in pixels.
        N (int): Resolution at X axis in pixels.
        pixel_size (float): Pixel size in meters.
        wavelengths (torch.Tensor): Wavelengths in meters.
        distance (float): Distance in meters.
        device (torch.device): Device, for more see torch.device().

    Returns:
        torch.Tensor: Propagated field.
    """
    r, _ = get_spatial_coords(M, N, pixel_size, device=device, type='polar')
    r = r.unsqueeze(0)
    c = torch.exp(1j * wave_number(wavelengths) * distance) / (1j * wavelengths * distance) \
        * torch.exp(1j * wave_number(wavelengths) / (2 * distance) * (r ** 2))
    c = c.to(device=device)
    result = fft(field) * c * pixel_size ** 2
    return result
def fraunhofer_inverse_propagation(field: torch.Tensor, pixel_size: float, wavelengths: torch.Tensor, distance: float, device: torch.device = torch.device('cpu')) -> torch.Tensor:
    r"""
    The inverse Fraunhofer approximation (to reconstruct the field at the source from the field at the sensor) is given by the following equation:

    .. math::
        U_0(x,y) \approx \frac{1}{j\lambda z} e^{j k z} e^{\frac{j k (x^2 + y^2)}{2z}} \mathcal{F}^{-1}\left\{ U(x,y) \right\}

    where :math:`U_0(x,y)` is the field at the source, :math:`U(x,y)` is the field at the sensor, :math:`\mathcal{F}^{-1}` is the inverse Fourier transform operator, :math:`k` is the wavenumber, :math:`\lambda` is the wavelength, and :math:`z` is the distance of propagation.

    Args:
        field (torch.Tensor): Field at the sensor.
        pixel_size (float): Pixel size in meters.
        wavelengths (torch.Tensor): Wavelengths in meters.
        distance (float): Distance in meters.
        device (torch.device): Device, for more see torch.device().

    Returns:
        torch.Tensor: Reconstructed field.
    """
    _, M, N = field.shape
    r, _ = get_spatial_coords(M, N, pixel_size, device=device, type='polar')
    r = r.unsqueeze(0)
    c = torch.exp(1j * wave_number(wavelengths) * distance) / (1j * wavelengths * distance) \
        * torch.exp(1j * wave_number(wavelengths) / (2 * distance) * r ** 2)
    c = c.to(device=device)
    result = ifft(field / pixel_size ** 2 / c)
    return result
def fft(field: torch.Tensor, axis=(-2, -1)) -> torch.Tensor:
    r"""
    Fast Fourier Transform of an optical field.

    Args:
        field (torch.Tensor): Input field.
        axis (tuple): Axes over which to perform the fft.

    Returns:
        torch.Tensor: Fourier transform of the input field.
    """
    field = torch.fft.fftshift(field, dim=axis)
    field = torch.fft.fft2(field, dim=axis)
    field = torch.fft.fftshift(field, dim=axis)
    return field
def ifft(field: torch.Tensor, axis=(-2, -1)) -> torch.Tensor:
    r"""
    Inverse Fast Fourier Transform of an optical field.

    Args:
        field (torch.Tensor): Input field.
        axis (tuple): Axes over which to perform the ifft.

    Returns:
        torch.Tensor: Inverse Fourier transform of the input field.
    """
    field = torch.fft.ifftshift(field, dim=axis)
    field = torch.fft.ifft2(field, dim=axis)
    field = torch.fft.ifftshift(field, dim=axis)
    return field
def scalar_diffraction_propagation(field: torch.Tensor, distance: float, pixel_size: float, wavelength: list, approximation: str) -> torch.Tensor:
    r"""
    Compute the optical field propagation using the scalar diffraction model given by the selected approximation.

    .. note::
        * if approximation is 'fresnel', the transfer function is calculated using :func:`colibri.optics.functional.transfer_function_fresnel`.
        * if approximation is 'angular_spectrum', the transfer function is calculated using :func:`colibri.optics.functional.transfer_function_angular_spectrum`.
        * if approximation is 'fraunhofer', the field is propagated using :func:`colibri.optics.functional.fraunhofer_propagation`.
        * if approximation is 'fraunhofer_inverse', the field is propagated using :func:`colibri.optics.functional.fraunhofer_inverse_propagation`.

    Args:
        field (torch.Tensor): Input optical field of shape (C, M, N).
        distance (float): Propagation distance in meters.
        pixel_size (float): Pixel size in meters.
        wavelength (list): List of wavelengths in meters.
        approximation (str): Approximation (or diffraction model) to use: "fresnel", "angular_spectrum", "fraunhofer" or "fraunhofer_inverse".

    Returns:
        torch.Tensor: The propagated optical field according to the selected approximation, of shape (C, M, N).
    """
    _, M, N = field.shape
    if approximation == "fresnel":
        H = transfer_function_fresnel(M, N, pixel_size, wavelength, distance, field.device)
    elif approximation == "angular_spectrum":
        H = transfer_function_angular_spectrum(M, N, pixel_size, wavelength, distance, field.device)
    elif approximation == "fraunhofer":
        return fraunhofer_propagation(field, M, N, pixel_size, wavelength, distance, field.device)
    elif approximation == "fraunhofer_inverse":
        return fraunhofer_inverse_propagation(field, pixel_size, wavelength, distance, field.device)
    else:
        raise NotImplementedError(f"{approximation} approximation is not implemented")

    U1 = fft(field)
    U2 = U1 * H
    result = ifft(U2)
    return result
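# Usage sketch (illustrative, not part of the module): propagate a square
# aperture illuminated by a plane wave over 1 mm with the angular spectrum
# model; grid size, pitch and distance are arbitrary demo values.
def _example_scalar_diffraction():
    wl = torch.tensor([550e-9]).view(-1, 1, 1)    # (L, 1, 1) for broadcasting
    field = torch.zeros(1, 128, 128, dtype=torch.complex64)
    field[:, 32:96, 32:96] = 1.0                  # square aperture
    out = scalar_diffraction_propagation(field, 1e-3, 1e-6, wl, "angular_spectrum")
    intensity = torch.abs(out) ** 2               # diffraction pattern at the sensor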
def circular_aperture(M: int, N: int, radius: float, pixel_size: float) -> torch.Tensor:
    r'''
    Create a circular aperture mask of a given radius and pixel_size of size (M, N).

    .. math::
        A(x, y) = \begin{cases}
            1 & \text{if } \sqrt{x^2 + y^2} \leq \text{radius} \\
            0 & \text{otherwise}
        \end{cases}

    where :math:`(x, y)` are the spatial coordinates of each pixel in meters and :math:`\text{radius}` is the radius of the aperture.

    Args:
        M (int): Resolution at Y axis in pixels.
        N (int): Resolution at X axis in pixels.
        radius (float): Radius of the circular aperture in meters.
        pixel_size (float): Pixel size in meters.

    Returns:
        torch.Tensor: A boolean mask that is True inside the radius and False outside.
    '''
    r, _ = get_spatial_coords(M, N, pixel_size, type='polar')
    return r <= radius
def height2phase(height_map: torch.Tensor, wavelengths: torch.Tensor, refractive_index: callable) -> torch.Tensor:
    r"""
    Convert a height map to a phase modulation.

    .. math::
        \Phi(x,y,\lambda) = k(\lambda) \Delta n(\lambda) h(x, y)

    where :math:`\Phi` is the phase modulation, :math:`h(x, y)` is the height map of the optical element, :math:`k(\lambda)` is the wavenumber for wavelength :math:`\lambda`, and :math:`\Delta n(\lambda)` is the refractive index difference between the propagation medium and the material of the optical element.

    Args:
        height_map (torch.Tensor): Height map.
        wavelengths (torch.Tensor): Wavelengths in meters.
        refractive_index (callable): Function that returns the refractive index difference :math:`\Delta n(\lambda)`.

    Returns:
        torch.Tensor: Phase modulation.
    """
    k0 = wave_number(wavelengths)
    phase_doe = refractive_index(wavelengths) * k0 * height_map
    return phase_doe
def psf_single_doe_spectral(height_map: torch.Tensor, aperture: torch.Tensor, refractive_index: callable,
                            wavelengths: torch.Tensor, source_distance: float, sensor_distance: float,
                            pixel_size: float, approximation: str = "fresnel") -> torch.Tensor:
    r"""
    Calculate the point spread function (PSF) of an optical system comprising a diffractive optical element (DOE) for spectral imaging.

    The PSF is calculated as follows:

    .. math::
        \mathbf{H}(\learnedOptics) = |\mathcal{P_2}(z_2, \lambda) \left( \mathcal{P_1}(z_1, \lambda)(\delta) * \learnedOptics \right)|^2

    where :math:`\mathcal{P_1}` is an operator that describes the propagation of light from the source to the DOE at a distance :math:`z_1`, :math:`\mathcal{P_2}` is an operator that describes the propagation of light from the DOE to the sensor at a distance :math:`z_2`, and :math:`\learnedOptics` is the DOE.

    The operator :math:`\mathcal{P_2}` depends on the given approximation:

    - Fresnel: :func:`colibri.optics.functional.transfer_function_fresnel`
    - Angular spectrum: :func:`colibri.optics.functional.transfer_function_angular_spectrum`
    - Fraunhofer: :func:`colibri.optics.functional.fraunhofer_propagation`

    Args:
        height_map (torch.Tensor): Height map of the DOE.
        aperture (torch.Tensor): Aperture mask.
        refractive_index (callable): Function to calculate the refractive index difference.
        wavelengths (torch.Tensor): Wavelengths in meters.
        source_distance (float): Distance from the source to the DOE in meters.
        sensor_distance (float): Distance from the DOE to the sensor in meters.
        pixel_size (float): Pixel size in meters.
        approximation (str): Type of propagation model ('fresnel', 'angular_spectrum', 'fraunhofer').

    Returns:
        torch.Tensor: PSF of the optical system, normalized to unit energy.
    """
    height_map = height_map * aperture
    M, N = height_map.shape
    wavelengths = wavelengths.unsqueeze(1).unsqueeze(2)
    k0 = wave_number(wavelengths)

    # Phase modulation introduced by the DOE
    doe = height2phase(height_map=torch.unsqueeze(height_map, 0), wavelengths=wavelengths,
                       refractive_index=refractive_index)
    doe = torch.exp(1j * doe * aperture) * aperture

    # Illumination: plane wave for a source at infinity, spherical wave otherwise
    optical_field = torch.ones_like(doe)
    if not (np.isinf(source_distance) or np.isnan(source_distance)):
        r, _ = get_spatial_coords(M, N, pixel_size, device=doe.device, type='polar')
        spherical_phase = (k0 / (2 * source_distance)) * (torch.unsqueeze(r, 0) ** 2)
        optical_field = torch.exp(1j * spherical_phase) * (1 / (1j * wavelengths * source_distance))

    optical_field = optical_field * doe
    optical_field_in_sensor = scalar_diffraction_propagation(field=optical_field,
                                                             distance=sensor_distance,
                                                             pixel_size=pixel_size,
                                                             wavelength=wavelengths,
                                                             approximation=approximation)
    psf = torch.abs(optical_field_in_sensor) ** 2
    psf = psf / torch.sum(psf, dim=(-2, -1), keepdim=True)
    return psf
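# Usage sketch (illustrative, not part of the module): spectral PSF of a random
# DOE behind a circular aperture, for a source at infinity. The height range,
# the constant index contrast of 0.5 and all geometry values are toy choices.
def _example_doe_psf():
    M = N = 128
    height_map = torch.rand(M, N) * 1e-6                    # heights in [0, 1] um
    aperture = circular_aperture(M, N, 50e-6, 1e-6).float()

    def delta_n(wl):
        return 0.5                                          # toy constant index contrast

    wl = torch.tensor([450e-9, 550e-9, 650e-9])
    psf = psf_single_doe_spectral(height_map, aperture, delta_n, wl,
                                  source_distance=float('inf'),
                                  sensor_distance=1e-3,
                                  pixel_size=1e-6)           # (3, 128, 128)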
def gaussian_noise(y: torch.Tensor, snr: float) -> torch.Tensor:
    r"""
    Add Gaussian noise to an image based on a specified signal-to-noise ratio (SNR).

    .. math::
        \mathbf{\tilde{y}} = \mathbf{y} + n

    where :math:`n` is Gaussian noise with zero mean whose variance :math:`\sigma^2` is derived from the SNR and the power of :math:`\mathbf{y}`.

    Args:
        y (torch.Tensor): Original image tensor with shape (B, L, M, N).
        snr (float): Desired signal-to-noise ratio in decibels (dB).

    Returns:
        torch.Tensor: Noisy image tensor with the same shape as the input.
    """
    # Noise variance per channel, derived from the signal power and the target SNR
    sigma_per_channel = torch.sum(torch.pow(y, 2), dim=(2, 3), keepdim=True) / (torch.numel(y[0, 0, ...]) * 10 ** (snr / 10))
    noise = torch.randn_like(y) * torch.sqrt(sigma_per_channel)
    return y + noise
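# Usage sketch (illustrative, not part of the module): corrupt an image at
# 30 dB and verify that the empirical SNR roughly matches the requested value.
def _example_gaussian_noise():
    y = torch.rand(1, 3, 64, 64)
    y_noisy = gaussian_noise(y, snr=30.0)
    snr_emp = 10 * torch.log10(y.pow(2).sum() / (y_noisy - y).pow(2).sum())  # approximately 30 dB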
def fourier_conv(image: torch.Tensor, psf: torch.Tensor) -> torch.Tensor:
    r"""
    Apply the Fourier convolution theorem to simulate the effect of a linear system characterized by a point spread function (PSF).

    .. math::
        g = \mathcal{F}^{-1}(\mathcal{F}(f) * \mathcal{F}(h))

    where :math:`f` is the input image, :math:`h` is the PSF, :math:`g` is the convolved output, and :math:`\mathcal{F}` and :math:`\mathcal{F}^{-1}` denote the Fourier and inverse Fourier transforms.

    Args:
        image (torch.Tensor): Image to simulate the sensing (B, L, M, N)
        psf (torch.Tensor): Point Spread Function (1, L, M, N)

    Returns:
        torch.Tensor: Measurement (B, L, M, N)
    """
    # Match the psf and image sizes by zero padding the smaller one
    psf_size = psf.shape[-2:]
    image_size = image.shape[-2:]
    extra_size = [(psf_size[i] - image_size[i]) for i in range(len(image_size))]
    if extra_size[0] < 0 or extra_size[1] < 0:
        psf = add_pad(psf, [0, 0, -extra_size[0] // 2, -extra_size[1] // 2])
    else:
        image = add_pad(image, [0, 0, extra_size[0] // 2, extra_size[1] // 2])

    img_fft = fft(image)
    otf = fft(psf)
    img_fft = img_fft * otf
    img = torch.abs(ifft(img_fft))

    if not (extra_size[0] < 0 or extra_size[1] < 0):
        img = unpad(img, pad=[0, 0, extra_size[0] // 2, extra_size[1] // 2])
    return img
def add_pad(x: torch.Tensor, pad: list) -> torch.Tensor:
    r"""
    Add zero padding to a tensor.

    .. note::
        * pad is a list of the same length as the number of dimensions of the tensor x.
        * each element of the pad list is an integer specifying the amount of padding to add on each side of the corresponding dimension of x.

    Args:
        x (torch.Tensor): Tensor to pad
        pad (list): Padding to add

    Returns:
        torch.Tensor: Padded tensor

    Example:
        >>> x = torch.tensor([[1, 2], [3, 4]])
        >>> add_pad(x, [1, 1])
        tensor([[0, 0, 0, 0],
                [0, 1, 2, 0],
                [0, 3, 4, 0],
                [0, 0, 0, 0]])
    """
    assert len(x.shape) == len(pad), "The tensor and the padding must have the same number of dimensions"
    pad_list = sum([[pa, pa] for pa in pad][::-1], [])
    return torch.nn.functional.pad(x, pad_list, mode='constant', value=0)
def unpad(x: torch.Tensor, pad: list) -> torch.Tensor:
    r"""
    Remove padding from a tensor.

    .. note::
        * pad is a list of the same length as the number of dimensions of the tensor x.
        * each element of the pad list is an integer specifying the amount of padding to remove on each side of the corresponding dimension of x.

    Args:
        x (torch.Tensor): Tensor to unpad
        pad (list): Padding to remove

    Returns:
        torch.Tensor: Unpadded tensor

    Example:
        >>> x = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]])
        >>> unpad(x, [1, 1])
        tensor([[1, 2],
                [3, 4]])
    """
    assert len(x.shape) == len(pad), "The tensor and the padding must have the same number of dimensions"
    if len(x.shape) == 4:
        return x[pad[0]:x.shape[0] - pad[0], pad[1]:x.shape[1] - pad[1],
                 pad[2]:x.shape[2] - pad[2], pad[3]:x.shape[3] - pad[3]]
    elif len(x.shape) == 3:
        return x[pad[0]:x.shape[0] - pad[0], pad[1]:x.shape[1] - pad[1], pad[2]:x.shape[2] - pad[2]]
    elif len(x.shape) == 2:
        return x[pad[0]:x.shape[0] - pad[0], pad[1]:x.shape[1] - pad[1]]
    else:
        raise ValueError("The tensor must have 2, 3 or 4 dimensions")
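# Usage sketch (illustrative, not part of the module): add_pad and unpad are
# inverses of each other when called with the same per-dimension pad list.
def _example_pad_roundtrip():
    x = torch.rand(1, 3, 16, 16)
    x_pad = add_pad(x, [0, 0, 4, 4])    # (1, 3, 24, 24)
    x_rec = unpad(x_pad, [0, 0, 4, 4])  # (1, 3, 16, 16)
    assert torch.equal(x, x_rec)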
def signal_conv(image: torch.Tensor, psf: torch.Tensor) -> torch.Tensor:
    r"""
    Apply the convolution of an image with a Point Spread Function (PSF) in the signal (spatial) domain.

    .. math::
        g(x, y) = f(x, y) * h(x, y)

    where :math:`f(x, y)` is the input image, :math:`h(x, y)` is the PSF, and :math:`g(x, y)` is the convolved output.

    Args:
        image (torch.Tensor): Image to simulate the sensing (B, L, M, N)
        psf (torch.Tensor): Point Spread Function (1, L, M, N)

    Returns:
        torch.Tensor: Measurement (B, L, M, N)
    """
    original_size = image.shape[-2:]
    # Rearrange the PSF from (1, L, M, N) to (L, 1, M, N), the weight layout
    # expected by conv2d with one group per spectral channel
    psf = psf.squeeze(0).unsqueeze(1)
    C, _, _, N = psf.shape
    # Flip the kernel to perform a true convolution; the padding assumes a square PSF
    image = torch.nn.functional.conv2d(image, torch.flip(psf, [-2, -1]), padding=(N - 1, N - 1), groups=C)
    new_size = image.shape[-2:]
    image = unpad(image, pad=[0, 0, (new_size[0] - original_size[0]) // 2, (new_size[1] - original_size[1]) // 2])
    return image
def convolutional_sensing(image: torch.Tensor, psf: torch.Tensor, domain='fourier') -> torch.Tensor:
    r"""
    Simulate the convolutional sensing model of an optical system, using either Fourier or spatial domain methods.

    The domain argument chooses whether the convolution is performed in the Fourier domain with :func:`colibri.optics.functional.fourier_conv` or in the spatial domain with :func:`colibri.optics.functional.signal_conv`.

    Args:
        image (torch.Tensor): Image tensor to simulate the sensing (B, L, M, N).
        psf (torch.Tensor): Point Spread Function (PSF) tensor (1, L, M, N).
        domain (str): Domain for the convolution operation, 'fourier' or 'signal'.

    Returns:
        torch.Tensor: Convolved image tensor as measurement (B, L, M, N).

    Raises:
        NotImplementedError: If the specified domain is not supported.
    """
    if domain == 'fourier':
        return fourier_conv(image, psf)
    elif domain == 'signal':
        return signal_conv(image, psf)
    else:
        raise NotImplementedError(f"{domain} domain is not implemented")
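# Usage sketch (illustrative, not part of the module): per-channel convolution
# of an RGB-like cube with an energy-normalized random PSF, followed by the
# ideal panchromatic sensor defined below to collapse the channels.
def _example_convolutional_sensing():
    image = torch.rand(1, 3, 64, 64)
    psf = torch.rand(1, 3, 64, 64)
    psf = psf / psf.sum(dim=(-2, -1), keepdim=True)          # unit-energy PSF per channel
    y = convolutional_sensing(image, psf, domain='fourier')  # (1, 3, 64, 64)
    y_pan = ideal_panchromatic_sensor(y)                     # (1, 1, 64, 64)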
def wiener_filter(image: torch.Tensor, psf: torch.Tensor, alpha: float) -> torch.Tensor:
    r"""
    Apply the Wiener filter to an image.

    .. math::
        X(x, y) = \mathcal{F}^{-1}\left\{Y(u, v) \frac{H^*(u, v)}{|H(u, v)|^2 + \alpha}\right\}

    where :math:`H(u, v)` is the optical transfer function, :math:`Y(u, v)` is the Fourier transform of the image, :math:`\alpha` is the regularization parameter, and :math:`X(x, y)` is the filtered image.

    Args:
        image (torch.Tensor): Image to apply the Wiener filter (B, L, M, N)
        psf (torch.Tensor): Point Spread Function (1, L, M, N)
        alpha (float): Regularization parameter

    Returns:
        torch.Tensor: Filtered image (B, L, M, N)
    """
    # Match the psf and image sizes by zero padding the smaller one
    psf_size = psf.shape[-2:]
    image_size = image.shape[-2:]
    extra_size = [(psf_size[i] - image_size[i]) for i in range(len(image_size))]
    if extra_size[0] < 0 or extra_size[1] < 0:
        psf = add_pad(psf, [0, 0, -extra_size[0] // 2, -extra_size[1] // 2])
    else:
        image = add_pad(image, [0, 0, extra_size[0] // 2, extra_size[1] // 2])

    img_fft = fft(image)
    otf = fft(psf)
    wiener = torch.conj(otf) / (torch.abs(otf) ** 2 + alpha)
    img_fft = img_fft * wiener
    img = torch.abs(ifft(img_fft))

    if not (extra_size[0] < 0 or extra_size[1] < 0):
        img = unpad(img, pad=[0, 0, extra_size[0] // 2, extra_size[1] // 2])
    return img
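# Usage sketch (illustrative, not part of the module): blur with a centered
# delta PSF (whose OTF is flat) and restore with the Wiener filter; with a
# small alpha the restoration is close to the original image.
def _example_wiener_filter():
    image = torch.rand(1, 3, 64, 64)
    psf = torch.zeros(1, 3, 64, 64)
    psf[..., 32, 32] = 1.0                                   # centered delta PSF
    blurred = convolutional_sensing(image, psf, domain='fourier')
    restored = wiener_filter(blurred, psf, alpha=1e-3)       # close to `image`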
def ideal_panchromatic_sensor(image: torch.Tensor) -> torch.Tensor:
    r"""
    Simulate the response of an ideal panchromatic sensor by averaging the spectral channels.

    .. math::
        I = \frac{1}{L} \sum_{\lambda} I_{\lambda}

    where :math:`I_{\lambda}` is the intensity at each wavelength and :math:`L` is the number of spectral channels.

    Args:
        image (torch.Tensor): Multispectral image tensor (B, L, M, N).

    Returns:
        torch.Tensor: Simulated sensor output as measurement (B, 1, M, N).
    """
    return torch.sum(image, dim=1, keepdim=True) / image.shape[1]