Fista

class colibri.recovery.fista.Fista(acquistion_model, fidelity=L2(), prior=Sparsity(), max_iters=5, alpha=0.001, _lambda=0.1)

Bases: Module

Fast Iterative Shrinkage-Thresholding Algorithm (FISTA)

The FISTA algorithm solves the optimization problem:

\[\begin{equation} \underset{\mathbf{x}}{\text{arg min}} \quad \frac{1}{2}||\mathbf{y} - \forwardLinear (\mathbf{x})||^2 + \lambda||\mathbf{x}||_1 \end{equation}\]

where \(\forwardLinear\) is the forward (acquisition) model, \(\mathbf{y}\) is the measured data, \(\mathbf{x}\) is the signal to be reconstructed, \(\lambda\) is the regularization parameter, and \(||\cdot||_1\) is the \(\ell_1\) norm.
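As a minimal illustration, the following NumPy sketch evaluates this objective for a dense matrix `A` standing in for \(\forwardLinear\). This is an assumption for illustration only: in colibri the forward model is an `nn.Module`, and the function name `lasso_objective` is hypothetical.

```python
import numpy as np

def lasso_objective(A, x, y, lam):
    """Evaluate 0.5 * ||y - A x||_2^2 + lam * ||x||_1 for a dense matrix A (hypothetical helper)."""
    residual = y - A @ x                  # data misfit y - A x
    fidelity = 0.5 * np.dot(residual, residual)
    sparsity = lam * np.sum(np.abs(x))    # l1 norm of x, weighted by lambda
    return fidelity + sparsity

A = np.array([[1.0, 0.0], [0.0, 2.0]])
y = np.array([1.0, 2.0])
x = np.array([1.0, 0.0])
print(lasso_objective(A, x, y, lam=0.1))
```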

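FISTA is an iterative method. Each iteration takes a gradient step on the fidelity term at an extrapolated point \(\mathbf{z}_k\), applies the proximal operator of the regularizer, and then updates the extrapolation (momentum) coefficient: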
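The two operations named above can be sketched for the L2 fidelity and an \(\ell_1\) prior, assuming a dense NumPy matrix `A` in place of the acquisition model. The gradient of \(f(\mathbf{x}) = \frac{1}{2}||\mathbf{y} - A\mathbf{x}||_2^2\) is \(A^\top(A\mathbf{x} - \mathbf{y})\), and the proximal operator of \(\tau||\cdot||_1\) is element-wise soft-thresholding. The helpers `grad_l2` and `soft_threshold` are hypothetical and are not colibri's `L2` or `Sparsity` classes.

```python
import numpy as np

def grad_l2(A, z, y):
    """Gradient of f(z) = 0.5 * ||y - A z||_2^2, i.e. A^T (A z - y)."""
    return A.T @ (A @ z - y)

def soft_threshold(v, tau):
    """Proximal operator of tau * ||.||_1: shrink each entry toward zero by tau."""
    return np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)

# One gradient step followed by one proximal step, starting from z = 0.
A = np.eye(3)
y = np.array([3.0, -0.5, 1.0])
alpha, lam = 1.0, 1.0
z = np.zeros(3)
x_next = soft_threshold(z - alpha * grad_l2(A, z, y), alpha * lam)  # threshold is alpha * lambda
print(x_next)
```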

\[\begin{split}\begin{align*} \mathbf{x}_{k+1} &= \text{prox}_{\alpha\lambda||\cdot||_1}( \mathbf{z}_k - \alpha \nabla f( \mathbf{z}_k)) \\ t_{k+1} &= \frac{1 + (1 + 4t_k^2)^{0.5}}{2} \\ \mathbf{z}_{k+1} &= \mathbf{x}_{k+1} + \frac{t_k-1}{t_{k+1}}( \mathbf{x}_{k+1} - \mathbf{x}_{k}) \end{align*}\end{split}\]

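where \(\alpha\) is the step size and \(f(\mathbf{x}) = \frac{1}{2}||\mathbf{y} - \forwardLinear(\mathbf{x})||^2\) is the fidelity term. In the standard formulation, the iteration starts from \(\mathbf{z}_0 = \mathbf{x}_0\) and \(t_0 = 1\), and convergence is guaranteed for \(\alpha \leq 1/L\), where \(L\) is the Lipschitz constant of \(\nabla f\).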
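Putting the three updates together gives a minimal NumPy sketch of the loop. It assumes a dense matrix `A` for the forward model, the L2 fidelity, and a plain \(\ell_1\) prior with threshold \(\alpha\lambda\). The function `fista` is hypothetical: it only mirrors the roles of `max_iters`, `alpha`, `_lambda`, and `x0`, and is not colibri's actual code.

```python
import numpy as np

def soft_threshold(v, tau):
    # Proximal operator of tau * ||.||_1 (element-wise shrinkage).
    return np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)

def fista(A, y, lam, alpha, max_iters=100, x0=None):
    """Minimal FISTA sketch for 0.5*||y - A x||_2^2 + lam*||x||_1 with a dense matrix A."""
    x = np.zeros(A.shape[1]) if x0 is None else x0.copy()
    z = x.copy()   # extrapolated point, z_0 = x_0
    t = 1.0        # t_0 = 1, so the first momentum coefficient is zero
    for _ in range(max_iters):
        x_prev = x
        grad = A.T @ (A @ z - y)                            # gradient of the fidelity at z_k
        x = soft_threshold(z - alpha * grad, alpha * lam)   # proximal step
        t_next = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0   # momentum sequence t_{k+1}
        z = x + ((t - 1.0) / t_next) * (x - x_prev)         # extrapolation z_{k+1}
        t = t_next
    return x

# Diagonal example: L = max(a_i^2) = 4, so alpha = 1/L = 0.25.
A = np.diag([1.0, 2.0])
y = np.array([1.0, 2.0])
x_hat = fista(A, y, lam=0.1, alpha=0.25, max_iters=500)
print(x_hat)  # approximately [0.9, 0.975]
```

For a diagonal `A` the problem separates per coordinate, with closed-form solution \(x_i = \text{soft}(a_i y_i, \lambda) / a_i^2\), which gives \((0.9, 0.975)\) here and serves as a quick sanity check on the loop.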

The implementation is based on the formulation of Beck and Teboulle: https://doi.org/10.1137/080716542

Parameters:
  • acquistion_model (nn.Module) – The acquisition model of the imaging system, i.e., the forward operator \(\forwardLinear\) that maps a signal to its measurements.

  • fidelity (nn.Module) – The fidelity term \(f\), which measures the discrepancy between the measurements and the model prediction. Defaults to L2().

  • prior (nn.Module) – The prior (regularization) term, which encodes prior knowledge about the solution. Defaults to Sparsity().

  • max_iters (int) – The maximum number of FISTA iterations. Defaults to 5.

  • alpha (float) – The step size \(\alpha\) of the gradient step. Defaults to 1e-3.

  • _lambda (float) – The regularization parameter \(\lambda\) that weights the prior term. Defaults to 0.1.

Returns:

None
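The default `alpha` of 1e-3 is not tied to the forward model. For the L2 fidelity with a linear forward model, a common choice is \(\alpha = 1/L\), where \(L = ||A||_2^2\) is the largest eigenvalue of \(A^\top A\). The sketch below estimates \(L\) by power iteration, assuming a dense NumPy matrix `A` standing in for `acquistion_model`; the helper `lipschitz_constant` is hypothetical and not part of colibri.

```python
import numpy as np

def lipschitz_constant(A, n_iters=100, seed=0):
    """Estimate L = ||A||_2^2, the largest eigenvalue of A^T A, by power iteration (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(A.shape[1])
    v /= np.linalg.norm(v)
    L = 0.0
    for _ in range(n_iters):
        w = A.T @ (A @ v)        # apply A^T A to the current unit vector
        L = np.linalg.norm(w)    # estimate of the largest eigenvalue
        v = w / L                # renormalize for the next iteration
    return L

A = np.diag([1.0, 2.0])
L = lipschitz_constant(A)
alpha = 1.0 / L   # largest step size covered by the convergence guarantee
print(L, alpha)
```

Power iteration only needs products with \(A\) and \(A^\top\), so the same idea carries over to operator-style forward models that are never formed as explicit matrices.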

forward(y, x0=None, verbose=False)

Runs the FISTA algorithm to solve the optimization problem.

Parameters:
  • y (torch.Tensor) – The measurement data \(\mathbf{y}\) from which the solution is reconstructed.

  • x0 (torch.Tensor, optional) – The initial guess for the solution. Defaults to None.

Returns:

The reconstructed image.

Return type:

torch.Tensor