E2E

class colibri.misc.e2e.E2E(optical_layer, decoder)[source]

Bases: Module

End-to-end model (E2E) for image reconstruction from compressed measurements.

In E2E models, the optical system and the computational decoder are modeled as layers of a neural network, denoted as \(\forwardLinear_{\learnedOptics}\) and \(\reconnet\), respectively. The optimization problem is formulated as follows:

\[\begin{split}\begin{equation} \begin{aligned} \{ \learnedOptics^*, \theta^* \} &= \arg \min_{\learnedOptics, \theta} \mathcal{L}(\learnedOptics, \theta) \\ & \coloneqq \sum_{p=1}^P \mathcal{L}_{\text{task}} (\reconnet(\forwardLinear_\learnedOptics(\mathbf{x}_p)), \mathbf{x}_p) + \lambda \mathcal{R}_1(\learnedOptics) + \mu \mathcal{R}_2(\theta) \end{aligned} \end{equation}\end{split}\]

where the optimal set of parameters of the optical system and the decoder is denoted by \(\{\learnedOptics^*, \theta^*\}\). The training dataset consists of input-output pairs \(\{\mathbf{x}_p, \mathbf{y}_p\}_{p=1}^P\). The loss function \(\mathcal{L}_{\text{task}}\) is chosen according to the task, typically mean squared error (MSE) for reconstruction and cross-entropy for classification. The regularization functions \(\mathcal{R}_1(\learnedOptics)\) and \(\mathcal{R}_2(\theta)\), weighted by the parameters \(\lambda\) and \(\mu\), prevent overfitting in the optical system and the decoder, respectively.

For more information, refer to: Deep Optical Coding Design in Computational Imaging: A Data-Driven Framework, https://doi.org/10.1109/MSP.2022.3200173
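The objective above can be sketched in PyTorch. This is a minimal illustration, not colibri's implementation: the toy sensing matrix standing in for \(\forwardLinear_{\learnedOptics}\), the \(\ell_1\) penalty used for \(\mathcal{R}_1\), and the weight-decay term used for \(\mathcal{R}_2\) are all assumptions chosen for brevity.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the optical layer: a learned linear sensing
# matrix H (the parameters phi). The real colibri optical layers model
# physical acquisition and are richer than this.
class ToyOpticalLayer(nn.Module):
    def __init__(self, n, m):
        super().__init__()
        self.H = nn.Parameter(torch.randn(m, n) / n**0.5)

    def forward(self, x):
        return x @ self.H.T  # y = H_phi(x)

optical = ToyOpticalLayer(n=64, m=16)
decoder = nn.Linear(16, 64)  # stand-in for the decoder N_theta

x = torch.randn(8, 64)       # a batch of training signals x_p
x_hat = decoder(optical(x))  # N_theta(H_phi(x_p))

lam, mu = 1e-3, 1e-4
task_loss = nn.functional.mse_loss(x_hat, x)            # L_task (MSE)
r1 = optical.H.abs().mean()                             # R_1(phi): assumed l1 penalty
r2 = sum(p.pow(2).sum() for p in decoder.parameters())  # R_2(theta): assumed weight decay
loss = task_loss + lam * r1 + mu * r2

loss.backward()  # gradients flow to both phi and theta jointly
```

Because the optical layer and the decoder sit in one computation graph, a single `backward()` call optimizes the acquisition and the reconstruction jointly, which is the defining feature of the E2E formulation.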

Parameters:
  • optical_layer (nn.Module) – Optical Layer module.

  • decoder (nn.Module) – Computational decoder module.

forward(x)[source]

Forward pass of the E2E model.

\[\begin{split}\begin{equation} \begin{aligned} \mathbf{y} &= \forwardLinear_{\learnedOptics}(\mathbf{x}) \\ \mathbf{x}_{\text{init}} &= \forwardLinear_{\learnedOptics}^\top(\mathbf{y})\\ \hat{\mathbf{x}} &=\reconnet(\mathbf{x}_{\text{init}}) \end{aligned} \end{equation}\end{split}\]
Parameters:

x (torch.Tensor) – Input tensor with shape (B, L, M, N).

Returns:

Output tensor with shape (B, L, M, N).

Return type:

torch.Tensor
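The three steps of the forward pass can be sketched as follows. This is a hedged, self-contained illustration, assuming a toy sensing layer with an explicit `adjoint` method for \(\forwardLinear_{\learnedOptics}^\top\); the class and method names are hypothetical, not colibri's API.

```python
import torch
import torch.nn as nn

# Hypothetical acquisition layer exposing both the forward operator and
# its adjoint; assumed names for illustration only.
class ToySensing(nn.Module):
    def __init__(self, m, n):
        super().__init__()
        self.H = nn.Parameter(torch.randn(m, n) / n**0.5)

    def forward(self, x):
        return x @ self.H.T  # H_phi(x)

    def adjoint(self, y):
        return y @ self.H    # H_phi^T(y)

class E2ESketch(nn.Module):
    """Minimal sketch of the E2E forward pass (not the colibri implementation)."""
    def __init__(self, optical_layer, decoder):
        super().__init__()
        self.optical_layer = optical_layer
        self.decoder = decoder

    def forward(self, x):
        y = self.optical_layer(x)               # y = H_phi(x)
        x_init = self.optical_layer.adjoint(y)  # x_init = H_phi^T(y)
        return self.decoder(x_init)             # x_hat = N_theta(x_init)

model = E2ESketch(ToySensing(m=16, n=64), nn.Linear(64, 64))
x = torch.randn(8, 64)
x_hat = model(x)  # reconstruction with the same shape as the input
```

Note that the adjoint is applied only to initialize the decoder's input at the signal's dimensionality; the decoder then refines this back-projection into the final estimate.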