Source code for colibri.recovery.solvers.core

import torch
from colibri.optics import BaseOpticsLayer


class Solver(object):
    r"""

    Base class for all solvers.
    """

    def __init__(self, y, acquisition_model: BaseOpticsLayer):
        r"""

        Initializes the solver.

        Args:
            y (torch.Tensor): Input tensor
            acquisition_model (BaseOpticsLayer): Acquisition model
        """
        pass

    def __call__(self, xtilde, rho):
        result = self.solve(xtilde, rho)
        return result

    def solve(self, xtilde, rho):
        raise NotImplementedError("Subclasses should implement the solve method.")


[docs] class L2L2Solver(Solver): r""" Base class for linear solvers. It describes the close-form solution of the optimization problem. .. math:: \min_{\textbf{x}} \frac{1}{2}||\textbf{y} - \textbf{H}\textbf{x}||_2^2 + \rho||\textbf{x} - \tilde{\textbf{x}}||_2^2 """ def __init__(self, y, acquisition_model: BaseOpticsLayer): r""" Args: y (torch.Tensor): Input tensor with shape (B, \*) acquisition_model (BaseOpticsLayer): Acquisition model """ super(L2L2Solver, self).__init__(y, acquisition_model) # vectorized form of y y_vec = y.view(y.size(0), -1) # batch matrix multiplication H = self.acquisition_model.get_sensing_matrix() H = H.unsqueeze(0).repeat(y.size(0), 1, 1) self.Hty = torch.bmm(H, y_vec.unsqueeze(-1)) self.HtH = torch.bmm(H, H.permute(0, 2, 1))
[docs] def solve(self, xtilde, rho): r""" Solves the optimization problem by computing the following expression: .. math:: \hat{\textbf{x}} = (\textbf{H}^\top\textbf{H} + \rho \textbf{I})^{-1}(\textbf{H}^\top\textbf{y} + \rho \tilde{\textbf{x}}) Args: xtilde (torch.Tensor): Regularizaton tensor rho (float): Regularization parameter Returns: torch.Tensor: Recovered tensor """ H_adjoint = self.HtH + rho * torch.eye(self.HtH.shape[1], device=self.HtH.device) H_adjoint = torch.inverse(H_adjoint) b, c, h, w = xtilde.size() x_hat = torch.bmm(H_adjoint, self.Hty + rho * xtilde.view(b, -1, 1)).view(b, c, h, w) return x_hat