PnP_ADMM
- class colibri.recovery.pnp.PnP_ADMM(acquisition_model, fidelity=L2(), prior=Sparsity(), solver='close', max_iters=20, _lambda=0.1, rho=0.1, alpha=0.01)[source]
Bases: Module
Plug-and-Play (PnP) algorithm with Alternating Direction Method of Multipliers (ADMM) formulation.
The PnP algorithm solves the optimization problem:
\[\begin{equation} \underset{\mathbf{x}}{\text{arg min}} \quad \frac{1}{2}||\mathbf{y} - \forwardLinear (\mathbf{x})||^2 + \lambda \mathcal{R}(\mathbf{x}) \end{equation}\]where \(\forwardLinear\) is the forward model, \(\mathbf{y}\) is the measured data, \(\lambda\) is the regularization parameter, and \(\mathcal{R}\) is the prior regularizer.
PnP-ADMM is an iterative algorithm. Each iteration alternates between a closed-form update for the data-fidelity term (which involves the forward model), a proximal or denoising step for the prior term, and an update of the scaled dual variable \(\mathbf{u}\):
\[\begin{split}\begin{align*} \mathbf{x}_{k+1} &= \underset{\mathbf{x}}{\text{arg min}} \quad \frac{1}{2}||\mathbf{y} - \forwardLinear (\mathbf{x})||^2 + \frac{\rho}{2} || \mathbf{x} - \tilde{\mathbf{x}}_k ||^2 \\ \mathbf{v}_{k+1} &= \text{prox}_{\mathcal{R}}( \tilde{\mathbf{v}}_k , \lambda) \\ \mathbf{u}_{k+1} &= \mathbf{u}_{k} + \mathbf{x}_{k+1} - \mathbf{v}_{k+1} \end{align*}\end{split}\]where \(\tilde{\mathbf{x}}_k = \mathbf{v}_{k} - \mathbf{u}_{k}\) and \(\tilde{\mathbf{v}}_k = \mathbf{x}_{k+1} + \mathbf{u}_{k}\). Here \(\rho\) is the penalty parameter that weights the quadratic coupling term in the \(\mathbf{x}\) update, and \(\lambda\) is the regularization parameter for the prior term.
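For intuition, the \(\mathbf{x}\) update has an explicit solution when the forward model is linear. As a sketch, assume the forward model acts as a real-valued matrix \(\mathbf{H}\), so that it maps \(\mathbf{x}\) to \(\mathbf{H}\mathbf{x}\) (the symbol \(\mathbf{H}\) is introduced here only for illustration). Setting the gradient of the \(\mathbf{x}\)-update objective to zero and solving for \(\mathbf{x}\) gives:
\[\begin{align*} \mathbf{H}^\top(\mathbf{H}\mathbf{x} - \mathbf{y}) + \rho\,\bigl(\mathbf{x} - (\mathbf{v}_k - \mathbf{u}_k)\bigr) &= \mathbf{0} \\ \mathbf{x}_{k+1} &= (\mathbf{H}^\top\mathbf{H} + \rho\,\mathbf{I})^{-1}\bigl(\mathbf{H}^\top\mathbf{y} + \rho\,(\mathbf{v}_k - \mathbf{u}_k)\bigr) \end{align*}\]Under this assumption the matrix \(\mathbf{H}^\top\mathbf{H} + \rho\,\mathbf{I}\) is invertible for any \(\rho > 0\), because \(\mathbf{H}^\top\mathbf{H}\) is positive semidefinite.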
The implementation is based on the formulation in https://doi.org/10.1109/TCI.2016.2629286
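The three updates can be sketched in plain Python. This is a minimal illustration under stated assumptions, not colibri's implementation: it assumes an element-wise forward model `y[i] = h[i] * x[i]` (so the \(\mathbf{x}\) update has a simple closed form), it uses soft-thresholding as the proximal step, and the names `soft_threshold` and `pnp_admm_diagonal` are hypothetical helpers that do not exist in the library.

```python
def soft_threshold(values, tau):
    """Proximal operator of tau * ||.||_1: shrink each entry toward zero by tau."""
    return [max(abs(val) - tau, 0.0) * (1.0 if val >= 0 else -1.0) for val in values]


def pnp_admm_diagonal(h, y, prox=soft_threshold, lam=0.1, rho=0.1, max_iters=20):
    """Toy PnP-ADMM loop for an element-wise forward model y[i] = h[i] * x[i].

    Hypothetical helper for illustration only; it is not part of colibri.
    `prox` can be any function (values, lam) -> values, e.g. a denoiser.
    """
    n = len(y)
    x = [0.0] * n
    v = [0.0] * n
    u = [0.0] * n
    for _ in range(max_iters):
        # x-update: closed-form minimizer of
        # 0.5 * (y - h * x)^2 + 0.5 * rho * (x - x_tilde)^2, entry by entry.
        x_tilde = [v[i] - u[i] for i in range(n)]
        x = [(h[i] * y[i] + rho * x_tilde[i]) / (h[i] ** 2 + rho) for i in range(n)]
        # v-update: proximal (or denoising) step applied to v_tilde = x + u.
        v_tilde = [x[i] + u[i] for i in range(n)]
        v = prox(v_tilde, lam)
        # u-update: accumulate the gap between the two primal variables.
        u = [u[i] + x[i] - v[i] for i in range(n)]
    return x


# Toy denoising case (h = 1 everywhere): the iterates approach the
# soft-thresholded data with threshold rho * lam, here [1.5, 0.0, -1.5].
x_hat = pnp_admm_diagonal([1.0, 1.0, 1.0], [2.0, 0.3, -2.0],
                          lam=0.5, rho=1.0, max_iters=100)
print(x_hat)
```

Because `prox` is passed in as an argument, the soft-thresholding step in this sketch can be swapped for any other denoising function with the same call shape, which is the "plug-and-play" idea.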
- Parameters:
acquisition_model (nn.Module) – The acquisition model of the imaging system. This is a function that models the process of data acquisition in the imaging system.
fidelity (nn.Module) – The fidelity term in the optimization problem. This is a function that measures the discrepancy between the data and the model prediction.
prior (nn.Module) – The prior term in the optimization problem. This is a function that encodes prior knowledge about the solution.
max_iters (int) – The maximum number of iterations for the PnP-ADMM algorithm. Defaults to 20.
_lambda (float) – The regularization parameter for the prior term. Defaults to 0.1.
rho (float) – The penalty parameter for the ADMM formulation. Defaults to 0.1.
alpha (float) – The step size for the gradient step. Defaults to 0.01.
- Returns:
None
- forward(y, x0=None, verbose=False)[source]
Runs the PnP-ADMM algorithm to solve the optimization problem.
- Parameters:
y (torch.Tensor) – The measurement data to be reconstructed.
x0 (torch.Tensor, optional) – The initial guess for the solution. Defaults to None.
- Returns:
The reconstructed image.
- Return type:
torch.Tensor