import torch.nn as nn
# NOTE: removed stray "[docs]" Sphinx HTML-export artifact that was not valid Python.
class E2E(nn.Module):
    def __init__(self, optical_layer: nn.Module, decoder: nn.Module):
        r"""
        End-to-end model (E2E) for image reconstruction from compressed measurements.

        In E2E models, the optical system and the computational decoder are modeled as
        layers of a neural network, denoted as :math:`\forwardLinear_{\learnedOptics}`
        and :math:`\reconnet`, respectively. The optimization problem is formulated as
        follows:

        .. math::
            \begin{equation}
            \begin{aligned}
            \{ \learnedOptics^*, \theta^* \} &= \arg \min_{\learnedOptics, \theta}
                \mathcal{L}(\learnedOptics, \theta) \\
            & \coloneqq \sum_{p=1}^P \mathcal{L}_{\text{task}}
                (\reconnet(\forwardLinear_\learnedOptics(\mathbf{x}_p)), \mathbf{x}_p)
                + \lambda \mathcal{R}_1(\learnedOptics) + \mu \mathcal{R}_2(\theta)
            \end{aligned}
            \end{equation}

        where the optimal set of parameters of the optical system and the decoder are
        represented by :math:`\{\learnedOptics^*, \theta^*\}`. The training dataset
        consists of input-output pairs :math:`\{\mathbf{x}_p, \mathbf{y}_p\}_{p=1}^P`.
        The loss function :math:`\mathcal{L}_{\text{task}}` is selected based on the
        task, usually mean square error (MSE) for reconstruction and cross-entropy for
        classification. Regularization functions :math:`\mathcal{R}_1(\learnedOptics)`
        and :math:`\mathcal{R}_2(\theta)`, with weights :math:`\lambda` and
        :math:`\mu`, are used to prevent overfitting in the optical system and the
        decoder.

        For more information refer to: Deep Optical Coding Design in Computational
        Imaging: A data-driven framework. https://doi.org/10.1109/MSP.2022.3200173

        Args:
            optical_layer (nn.Module): Optical layer module. Must accept a keyword
                argument ``type_calculation`` ("forward" or "backward") selecting
                :math:`\forwardLinear_{\learnedOptics}` or its adjoint.
            decoder (nn.Module): Computational decoder module :math:`\reconnet`.
        """
        super(E2E, self).__init__()
        self.optical_layer = optical_layer
        self.decoder = decoder

    def forward(self, x):
        r"""
        Forward pass of the E2E model.

        Computes the measurement, a backward-projected initialization, and the
        decoded reconstruction:

        .. math::
            \begin{equation}
            \begin{aligned}
            \mathbf{y} &= \forwardLinear_{\learnedOptics}(\mathbf{x}) \\
            \mathbf{x}_{\text{init}} &= \forwardLinear_{\learnedOptics}^\top(\mathbf{y}) \\
            \hat{\mathbf{x}} &= \reconnet(\mathbf{x}_{\text{init}})
            \end{aligned}
            \end{equation}

        Args:
            x (torch.Tensor): Input tensor with shape (B, L, M, N).

        Returns:
            torch.Tensor: Output tensor with shape (B, L, M, N).
        """
        y = self.optical_layer(x)                                   # y = A(x)
        x_init = self.optical_layer(y, type_calculation="backward")  # x_init = A^T(y)
        x_hat = self.decoder(x_init)                                # x_hat = R(x_init)
        return x_hat