LFSI
- class colibri.recovery.lfsi.LFSI(max_iters, filter=None, p=0.6, k_size=5, sigma=1.0, train_filter=False, dtype=torch.float32, device=device(type='cpu'))
Bases: Module
Learned Filtered Spectral Initialization (LFSI).
This layer implements the Learned Filtered Spectral Initialization method for estimating the optical field from coded measurements of a Coded Phase Imaging System.
This algorithm approximates an initial guess of the optical field \(\mathbf{x}^{(0)}\) by computing a filtered estimate of the leading eigenvector of a matrix formed from the sensing matrix and the measurements. For more information, refer to: Learning spectral initialization for phase retrieval via deep neural networks, https://doi.org/10.1364/AO.445085
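A compact way to read the description above is as a power iteration on a measurement-weighted matrix, with a filtering step after each multiplication. This assumes the standard phase-retrieval model \(y_i = |\mathbf{a}_i^H\mathbf{x}|^2\) for \(i = 1,\dots,m\), where \(\mathbf{a}_i^H\) is the \(i\)-th row of the sensing matrix \(\mathbf{A}\), \(\mathbf{y}_{tr}\) is the 0/1 mask returned by get_ytr, and \(\mathbf{h}\) is the filter (a Gaussian kernel by default). It is a sketch of the general filtered spectral scheme, not a transcription of colibri's code; the normalization and the exact point where the filter is applied are assumptions.

\[
\mathbf{Y} = \frac{1}{m}\sum_{i=1}^{m} y_{tr,i}\,\mathbf{a}_i\mathbf{a}_i^H = \frac{1}{m}\mathbf{A}^H \operatorname{diag}(\mathbf{y}_{tr})\,\mathbf{A}, \qquad
\mathbf{z}^{(t+1)} = \frac{\mathbf{h} * \big(\mathbf{Y}\mathbf{z}^{(t)}\big)}{\big\|\mathbf{h} * \big(\mathbf{Y}\mathbf{z}^{(t)}\big)\big\|_2}, \qquad
\mathbf{x}^{(0)} \propto \mathbf{z}^{(T)},
\]

with \(T\) = max_iters and \(*\) denoting 2D convolution. If \(\mathbf{h}\) is a unit impulse (no filtering), this reduces to plain power iteration, whose iterates converge to the leading eigenvector of \(\mathbf{Y}\) whenever its largest eigenvalue is strictly dominant.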
Initializes the Learned Filtered Spectral Initialization layer.
- Parameters:
max_iters (int) – The maximum number of iterations.
filter (torch.Tensor, optional) – The filter tensor. If None, a Gaussian kernel will be used. Default is None.
p (float, optional) – The fraction (between 0 and 1) of the largest values to keep; for example, 0.6 keeps the top 60%. Default is 0.6.
k_size (int, optional) – The size of the Gaussian kernel. Default is 5.
sigma (float, optional) – The standard deviation of the Gaussian kernel. Default is 1.0.
train_filter (bool, optional) – If True, the filter will be trainable. Default is False.
dtype (torch.dtype, optional) – The data type of the filter tensor. Default is torch.float32.
device (torch.device, optional) – The device on which the filter tensor will be allocated. Default is CPU.
- Returns:
None
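When filter is None, a Gaussian kernel of size k_size and standard deviation sigma is used. The docstring does not give the exact formula, so the NumPy sketch below assumes the usual normalized, isotropic 2D Gaussian sampled on a k_size x k_size grid centred on the middle tap. The helper name gaussian_kernel is hypothetical; colibri itself stores the filter as a torch.Tensor with the requested dtype and device.

```python
import numpy as np


def gaussian_kernel(k_size: int = 5, sigma: float = 1.0) -> np.ndarray:
    """Hypothetical helper: normalized k_size x k_size isotropic Gaussian kernel."""
    # Tap offsets centred on the middle tap, e.g. [-2, -1, 0, 1, 2] for k_size=5.
    ax = np.arange(k_size) - (k_size - 1) / 2.0
    xx, yy = np.meshgrid(ax, ax)
    kernel = np.exp(-(xx ** 2 + yy ** 2) / (2.0 * sigma ** 2))
    # Normalize so the taps sum to 1 (filtering then preserves the mean value).
    return kernel / kernel.sum()


k = gaussian_kernel(5, 1.0)
print(k.shape)  # (5, 5)
```

A larger sigma spreads the weights and smooths more strongly; an odd k_size keeps a well-defined centre tap.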
- get_ytr(y)
This function generates a mask that keeps the largest p fraction (that is, the top 100·p %) of the values of the input tensor.
- Parameters:
y (torch.Tensor) – The input tensor.
- Returns:
The mask tensor
- Return type:
torch.Tensor
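get_ytr is documented only as returning a mask over the top p fraction of values. A minimal NumPy analogue follows, assuming the number of kept entries is round(p · N) and that the mask holds 1.0 for kept entries and 0.0 elsewhere; colibri operates on torch tensors and may select the entries differently (for example, by thresholding at a quantile). The name top_p_mask is hypothetical.

```python
import numpy as np


def top_p_mask(y: np.ndarray, p: float = 0.6) -> np.ndarray:
    """Hypothetical analogue of get_ytr: 1.0 for the largest p-fraction of y, else 0.0."""
    flat = y.ravel()
    k = int(round(p * flat.size))  # number of entries to keep
    mask = np.zeros(flat.size)
    if k > 0:
        # argsort is ascending, so the last k indices point at the k largest values.
        mask[np.argsort(flat)[-k:]] = 1.0
    return mask.reshape(y.shape)


y = np.arange(1.0, 11.0)   # [1, 2, ..., 10]
print(top_p_mask(y, 0.6))  # [0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
```

Ties at the cut-off are broken by argsort order, so equal values may be split between kept and dropped entries.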
- forward(inputs, optical_layer)
This function estimates the optical field from coded measurements of a Coded Phase Imaging System using the Learned Filtered Spectral Initialization method.
- Parameters:
inputs (torch.Tensor) – The input tensor.
optical_layer (torch.nn.Module) – The optical layer that represents the Coded Phase Imaging System.
- Returns:
The estimated optical field.
- Return type:
torch.Tensor
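The docstring does not say how forward uses optical_layer internally. The NumPy sketch below illustrates a filtered power iteration under stated assumptions: the callables A and AH are hypothetical stand-ins for the optical layer's forward and adjoint operators, y_tr is the 0/1 mask of the top p fraction of measurements, the filter is applied after every multiplication by Y, and max_iters counts those steps. It is not colibri's implementation; the names conv2d_same and filtered_spectral_init are hypothetical.

```python
import numpy as np


def conv2d_same(img, kernel):
    """Zero-padded, same-size 2D convolution (assumes an odd-sized kernel)."""
    kh, kw = kernel.shape
    ph, pw = kh // 2, kw // 2
    padded = np.pad(img, ((ph, ph), (pw, pw)))
    flipped = kernel[::-1, ::-1]  # flip so this is convolution, not correlation
    out = np.zeros_like(img)
    for i in range(kh):
        for j in range(kw):
            out += flipped[i, j] * padded[i:i + img.shape[0], j:j + img.shape[1]]
    return out


def filtered_spectral_init(y, A, AH, shape, kernel, p=0.6, max_iters=100, seed=0):
    """Filtered power iteration on Y = (1/m) A^H diag(y_tr) A.

    y: 1-D measurements of length m; A, AH: hypothetical forward/adjoint callables
    standing in for the optical layer; shape: shape of the unknown field.
    """
    m = y.size
    # 0/1 mask of the round(p * m) largest measurements (the role of get_ytr).
    k = int(round(p * m))
    y_tr = np.zeros(m)
    if k > 0:
        y_tr[np.argsort(y)[-k:]] = 1.0
    rng = np.random.default_rng(seed)
    z = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
    z /= np.linalg.norm(z)
    for _ in range(max_iters):
        z = AH(y_tr * A(z)) / m     # one multiplication by Y
        z = conv2d_same(z, kernel)  # smooth the iterate with the filter
        z /= np.linalg.norm(z)      # renormalize to unit norm
    return z


# Usage on a random complex Gaussian sensing matrix (illustrative data only).
rng = np.random.default_rng(1)
shape, m = (4, 4), 2000
n = shape[0] * shape[1]
A_mat = (rng.standard_normal((m, n)) + 1j * rng.standard_normal((m, n))) / np.sqrt(2)


def A(x):
    return A_mat @ x.ravel()


def AH(r):
    return (A_mat.conj().T @ r).reshape(shape)


x_true = rng.standard_normal(shape)
y = np.abs(A(x_true)) ** 2
x0 = filtered_spectral_init(y, A, AH, shape, kernel=np.ones((3, 3)) / 9.0, max_iters=50)
```

Passing a 1x1 identity kernel turns the sketch into plain power iteration, which is a convenient sanity check against the leading eigenvector computed directly from Y. In LFSI the filter can additionally be made trainable via train_filter.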