Source code for colibri.misc.e2e

import torch.nn as nn


class E2E(nn.Module):
    def __init__(self, optical_layer: nn.Module, decoder: nn.Module):
        r"""
        End-to-end model (E2E) for image reconstruction from compressed measurements.

        In E2E models, the optical system and the computational decoder are modeled as layers of a
        neural network, denoted as :math:`\forwardLinear_{\learnedOptics}` and :math:`\reconnet`,
        respectively. The optimization problem is formulated as follows:

        .. math::
            \begin{equation}
                \begin{aligned}
                    \{ \learnedOptics^*, \theta^* \} &= \arg \min_{\learnedOptics, \theta} \mathcal{L}(\learnedOptics, \theta) \\
                    & \coloneqq \sum_{p=1}^P \mathcal{L}_{\text{task}} (\reconnet(\forwardLinear_\learnedOptics(\mathbf{x}_p)), \mathbf{x}_p) + \lambda \mathcal{R}_1(\learnedOptics) + \mu \mathcal{R}_2(\theta)
                \end{aligned}
            \end{equation}

        where the optimal set of parameters of the optical system and the decoder is denoted by
        :math:`\{\learnedOptics^*, \theta^*\}`. The training dataset consists of input-output pairs
        :math:`\{\mathbf{x}_p, \mathbf{y}_p\}_{p=1}^P`. The loss function :math:`\mathcal{L}_{\text{task}}`
        is selected according to the task, typically mean squared error (MSE) for reconstruction and
        cross-entropy for classification. The regularization functions :math:`\mathcal{R}_1(\learnedOptics)`
        and :math:`\mathcal{R}_2(\theta)`, weighted by :math:`\lambda` and :math:`\mu`, prevent
        overfitting in the optical system and the decoder.

        For more information refer to: Deep Optical Coding Design in Computational Imaging:
        A data-driven framework, https://doi.org/10.1109/MSP.2022.3200173

        Args:
            optical_layer (nn.Module): Optical layer module.
            decoder (nn.Module): Computational decoder module.
        """
        super(E2E, self).__init__()
        self.optical_layer = optical_layer
        self.decoder = decoder
    def forward(self, x):
        r"""
        Forward pass of the E2E model.

        .. math::
            \begin{equation}
                \begin{aligned}
                    \mathbf{y} &= \forwardLinear_{\learnedOptics}(\mathbf{x}) \\
                    \mathbf{x}_{\text{init}} &= \forwardLinear_{\learnedOptics}^\top(\mathbf{y}) \\
                    \hat{\mathbf{x}} &= \reconnet(\mathbf{x}_{\text{init}})
                \end{aligned}
            \end{equation}

        Args:
            x (torch.Tensor): Input tensor with shape (B, L, M, N).

        Returns:
            torch.Tensor: Output tensor with shape (B, L, M, N).
        """
        y = self.optical_layer(x)                                     # y = A(x)
        x_init = self.optical_layer(y, type_calculation="backward")   # x_init = A^T(y)
        x_hat = self.decoder(x_init)                                  # x_hat = R(x_init)
        return x_hat
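
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library API). The two toy modules
# below are hypothetical stand-ins that only illustrate the interface E2E
# expects: an optical layer callable as ``layer(x)`` for the forward operator
# A and as ``layer(y, type_calculation="backward")`` for its adjoint A^T
# (interface taken from ``forward`` above), plus any ``nn.Module`` decoder.
# The regularizers and their weights mirror the loss in the docstring,
# L_task(R(A(x)), x) + lambda*R1(phi) + mu*R2(theta), with illustrative values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    class ToyOpticalLayer(nn.Module):
        """Hypothetical coded-aperture layer: element-wise learnable mask."""

        def __init__(self, shape):
            super().__init__()
            self.mask = nn.Parameter(torch.rand(*shape))  # learnable optics

        def forward(self, x, type_calculation="forward"):
            if type_calculation == "forward":
                return self.mask * x   # y = A(x)
            elif type_calculation == "backward":
                return self.mask * x   # x_init = A^T(y); the mask is self-adjoint here
            raise ValueError(f"unknown type_calculation: {type_calculation}")

    class ToyDecoder(nn.Module):
        """Hypothetical decoder: a single conv layer standing in for R."""

        def __init__(self, channels):
            super().__init__()
            self.net = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        def forward(self, x_init):
            return self.net(x_init)

    B, L, M, N = 2, 3, 32, 32
    model = E2E(ToyOpticalLayer((L, M, N)), ToyDecoder(L))

    x = torch.rand(B, L, M, N)
    x_hat = model(x)  # y = A(x); x_init = A^T(y); x_hat = R(x_init)

    # One regularized training step following the objective above.
    lam, mu = 1e-3, 1e-4
    task_loss = nn.functional.mse_loss(x_hat, x)                        # L_task for reconstruction
    r1 = model.optical_layer.mask.abs().mean()                          # R1 on the optics
    r2 = sum(p.pow(2).sum() for p in model.decoder.parameters())        # R2 on the decoder
    loss = task_loss + lam * r1 + mu * r2
    loss.backward()
    print(x_hat.shape, float(loss))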