Source code for colibri.recovery.pnp

import torch
from torch import nn

from .solvers import SOLVERS, get_solver

from colibri.recovery.terms.fidelity import L2
from colibri.recovery.terms.prior import Sparsity

class PnP_ADMM(nn.Module):
    r"""
    Plug-and-Play (PnP) algorithm with Alternating Direction Method of Multipliers (ADMM) formulation.

    The PnP algorithm solves the optimization problem:

    .. math::
        \begin{equation}
            \underset{\mathbf{x}}{\text{arg min}} \quad \frac{1}{2}||\mathbf{y} - \forwardLinear (\mathbf{x})||^2 + \lambda \mathcal{R}(\mathbf{x})
        \end{equation}

    where :math:`\forwardLinear` is the forward model, :math:`\mathbf{y}` is the data to be reconstructed,
    :math:`\lambda` is the regularization parameter and :math:`\mathcal{R}` is the prior regularizer.

    The PnP ADMM algorithm is iterative. Each iteration performs a data-fidelity update for the forward model
    (a closed-form solution or a gradient step) followed by a proximal/denoising step for the prior term:

    .. math::
        \begin{align*}
            \mathbf{x}_{k+1} &= \underset{\textbf{x}}{\text{arg min}} \quad \frac{1}{2}||\mathbf{y} - \forwardLinear (\mathbf{x})||^2 + \frac{\rho}{2} || \mathbf{x} - \tilde{\mathbf{x}} ||^2 \\
            \mathbf{v}_{k+1} &= \text{prox}_{\mathcal{R}}( \tilde{\mathbf{v}}_k , \lambda) \\
            \mathbf{u}_{k+1} &= \mathbf{u}_{k} + \mathbf{x}_{k+1} - \mathbf{v}_{k+1}
        \end{align*}

    with :math:`\tilde{\mathbf{x}} = \mathbf{v}_{k} - \mathbf{u}_{k}` and :math:`\tilde{\mathbf{v}} = \mathbf{x}_{k+1} + \mathbf{u}_{k}`.
    Here :math:`\rho` is the regularization influence in the :math:`\mathbf{x}` update and :math:`\lambda`
    is the regularization parameter for the prior term.

    Implementation based on the formulation of authors in https://doi.org/10.1109/TCI.2016.2629286
    """

    def __init__(
        self,
        acquisition_model,
        fidelity=None,
        prior=None,
        solver="close",
        max_iters=20,
        _lambda=0.1,
        rho=0.1,
        alpha=0.01,
    ):
        r"""
        Args:
            acquisition_model (nn.Module): The acquisition model of the imaging system. It models the
                process of data acquisition; its ``forward`` method is used as the forward operator ``H``.
            fidelity (nn.Module, optional): The fidelity term in the optimization problem. It measures the
                discrepancy between the data and the model prediction. Must provide ``grad(x, y, H)`` and
                ``forward(x, y, H)``. Defaults to a new ``L2()`` instance per model.
            prior (nn.Module, optional): The prior term in the optimization problem. It encodes prior
                knowledge about the solution. Must provide ``prox(x, _lambda)``. Defaults to a new
                ``Sparsity("dct")`` instance per model.
            solver (str): ``"close"`` uses the closed-form x-update returned by
                ``get_solver(acquisition_model)``. Any other value uses a gradient step on the fidelity
                term. Defaults to ``"close"``.
            max_iters (int): The maximum number of ADMM iterations. Defaults to 20.
            _lambda (float): The regularization parameter for the prior term. Defaults to 0.1.
            rho (float): The penalty parameter for the ADMM formulation. Defaults to 0.1.
            alpha (float): The step size for the gradient step (used only when ``solver != "close"``).
                Defaults to 0.01.

        Returns:
            None
        """
        super().__init__()
        # Build the defaults per instance. Default arguments are evaluated once, so an
        # ``L2()`` / ``Sparsity("dct")`` default would be one module shared by every model.
        if fidelity is None:
            fidelity = L2()
        if prior is None:
            prior = Sparsity("dct")

        self.fidelity = fidelity
        self.acquisition_model = acquisition_model
        self.prior = prior
        self.solver = solver
        self.max_iters = max_iters
        self._lambda = _lambda
        self.rho = rho
        self.alpha = alpha
        self.H = lambda x: self.acquisition_model.forward(x)

    def forward(self, y, x0=None, verbose=False):
        r"""Run the PnP-ADMM iterations to solve the optimization problem.

        Args:
            y (torch.Tensor): The measurement data to be reconstructed.
            x0 (torch.Tensor, optional): The initial guess for the solution. Defaults to None, in which
                case ``torch.zeros_like(y)`` is used.
            verbose (bool): If True, print the fidelity value after every iteration. Defaults to False.

        Returns:
            torch.Tensor: The reconstructed image (the last ``v`` iterate). If ``max_iters`` is 0,
            this is ``x0``.
        """
        # Choose the x-subproblem update once. Both variants share the signature
        # ``x_update(x_prev, xtilde)`` so the loop does not branch on ``self.solver``.
        if self.solver == "close":
            closed_solution = get_solver(self.acquisition_model)(y, self.acquisition_model)

            def x_update(x_prev, xtilde):
                return closed_solution(xtilde, self.rho)

        else:

            def x_update(x_prev, xtilde):
                # NOTE(review): the proximal penalty term is not scaled by ``alpha``.
                # A plain gradient step on the augmented objective would be
                # x - alpha * (grad + rho * (x - xtilde)). Confirm this is intended.
                return (
                    x_prev
                    - self.alpha * self.fidelity.grad(x_prev, y, self.H)
                    - self.rho * (x_prev - xtilde)
                )

        if x0 is None:
            # NOTE(review): this assumes the solution has the same shape as y. Confirm
            # this holds for acquisition models whose output shape differs from the input.
            x0 = torch.zeros_like(y)

        # x_t must exist before the first gradient-step update, which reads it.
        x_t = x0
        u_t = torch.zeros_like(x0)
        v_t = x0

        for i in range(self.max_iters):
            # x-subproblem update
            xtilde = v_t - u_t
            x_t = x_update(x_t, xtilde)

            # v-subproblem update (proximal / denoising step on the prior)
            vtilde = x_t + u_t
            v_t = self.prior.prox(vtilde, self._lambda)

            # u-subproblem update (scaled dual variable)
            u_t = u_t + x_t - v_t

            if verbose:
                error = self.fidelity.forward(x_t, y, self.H).item()
                print("Iter: ", i, "fidelity: ", error)

        return v_t