Fista
- class colibri.recovery.fista.Fista(acquistion_model, fidelity=L2(), prior=Sparsity(), max_iters=5, alpha=0.001, _lambda=0.1)
Bases: torch.nn.Module
Fast Iterative Shrinkage-Thresholding Algorithm (FISTA)
The FISTA algorithm solves the optimization problem:
\[\underset{\mathbf{x}}{\text{arg min}} \quad \frac{1}{2}||\mathbf{y} - \forwardLinear (\mathbf{x})||_2^2 + \lambda||\mathbf{x}||_1\]
where \(\forwardLinear\) is the forward model, \(\mathbf{y}\) is the measured data, \(\mathbf{x}\) is the image to be reconstructed, \(\lambda\) is the regularization parameter and \(||\cdot||_1\) is the L1 norm.
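To make the objective concrete, the following minimal NumPy sketch evaluates it for a small problem. It assumes the forward model is an explicit matrix `A` (the class itself takes a torch acquisition module instead), and the function name `lasso_objective` is hypothetical, chosen only for illustration.

```python
import numpy as np

def lasso_objective(A: np.ndarray, x: np.ndarray, y: np.ndarray, lam: float) -> float:
    """Evaluate 0.5 * ||y - A x||_2^2 + lam * ||x||_1.

    Assumption: A is an explicit matrix standing in for the forward model.
    """
    residual = y - A @ x
    data_term = 0.5 * float(residual @ residual)   # L2 fidelity
    prior_term = lam * float(np.abs(x).sum())      # L1 regularization
    return data_term + prior_term

# With A = I, x = 0 and y = [3, -4], only the data term contributes: 0.5 * 25 = 12.5
A = np.eye(2)
y = np.array([3.0, -4.0])
value = lasso_objective(A, np.zeros(2), y, lam=0.1)
```

Evaluating this quantity across iterations is a simple way to check that a chosen step size and regularization parameter behave as expected.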
FISTA is an iterative method. Each iteration takes a gradient step on the fidelity term at an extrapolated point \(\mathbf{z}_k\), applies the proximal operator of the prior, and updates the extrapolation with a momentum term:
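\[\begin{aligned} \mathbf{x}_{k+1} &= \text{prox}_{\alpha\lambda||\cdot||_1}\left( \mathbf{z}_k - \alpha \nabla f( \mathbf{z}_k)\right) \\ t_{k+1} &= \frac{1 + \sqrt{1 + 4t_k^2}}{2} \\ \mathbf{z}_{k+1} &= \mathbf{x}_{k+1} + \frac{t_k-1}{t_{k+1}}\left( \mathbf{x}_{k+1} - \mathbf{x}_{k}\right) \end{aligned}\]
where \(\alpha\) is the step size, \(f(\mathbf{x}) = \frac{1}{2}||\mathbf{y} - \forwardLinear(\mathbf{x})||_2^2\) is the fidelity term, and \(\text{prox}_{\alpha\lambda||\cdot||_1}\) is element-wise soft-thresholding with threshold \(\alpha\lambda\).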
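The three update rules can be sketched in NumPy as follows. This is an illustrative sketch, not the library's implementation: it assumes an explicit matrix `A` in place of the acquisition model, hard-codes the L2 fidelity and L1 prior, starts from \(t_0 = 1\), and treats `x0=None` as a zero initial guess (an assumption; the class may initialize differently). The names `soft_threshold` and `fista` are hypothetical.

```python
import numpy as np

def soft_threshold(v: np.ndarray, tau: float) -> np.ndarray:
    """Proximal operator of tau * ||.||_1: element-wise soft-thresholding."""
    return np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)

def fista(A, y, lam=0.1, alpha=1e-3, max_iters=5, x0=None):
    """Minimal FISTA sketch for 0.5 * ||y - A x||^2 + lam * ||x||_1.

    Assumptions: A is an explicit matrix, x0=None means a zero start,
    and the momentum parameter starts at t = 1.
    """
    x_prev = np.zeros(A.shape[1]) if x0 is None else np.asarray(x0, dtype=float)
    z = x_prev.copy()
    t = 1.0
    for _ in range(max_iters):
        grad = A.T @ (A @ z - y)                            # gradient of the L2 fidelity at z_k
        x = soft_threshold(z - alpha * grad, alpha * lam)   # gradient step, then proximal step
        t_next = (1.0 + np.sqrt(1.0 + 4.0 * t**2)) / 2.0    # momentum parameter update
        z = x + ((t - 1.0) / t_next) * (x - x_prev)         # extrapolation from the last two iterates
        x_prev, t = x, t_next
    return x_prev

# With A = I and alpha = 1, one iteration already yields soft_threshold(y, lam).
x_hat = fista(np.eye(3), np.array([3.0, -0.05, -4.0]), lam=0.1, alpha=1.0, max_iters=10)
```

Note that the threshold passed to the proximal step is `alpha * lam`, matching the \(\text{prox}_{\alpha\lambda||\cdot||_1}\) operator in the update rule.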
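The implementation follows the formulation of Beck and Teboulle, "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse Problems", SIAM Journal on Imaging Sciences, 2009: https://doi.org/10.1137/080716542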
- Parameters:
acquistion_model (nn.Module) – The acquisition model of the imaging system, which models how measurements are formed from the image. It plays the role of the forward model \(\forwardLinear\).
fidelity (nn.Module) – The fidelity term of the optimization problem, which measures the discrepancy between the measured data and the prediction of the acquisition model. Defaults to L2().
prior (nn.Module) – The prior term of the optimization problem, which encodes prior knowledge about the solution. Defaults to Sparsity().
max_iters (int) – The maximum number of FISTA iterations. Defaults to 5.
alpha (float) – The step size for the gradient step. Convergence is guaranteed for alpha ≤ 1/L, where L is the Lipschitz constant of the gradient of the fidelity term. Defaults to 1e-3.
_lambda (float) – The regularization parameter weighting the prior term. Defaults to 0.1.
- Returns:
None
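Because convergence depends on alpha ≤ 1/L, it can help to estimate L before choosing alpha; the default of 1e-3 is only guaranteed safe when L ≤ 1000. For the L2 fidelity with a linear forward model given as an explicit matrix `A` (an assumption; the class works with a torch acquisition module), L is the largest eigenvalue of AᵀA, which this hypothetical helper `lipschitz_constant` estimates by power iteration.

```python
import numpy as np

def lipschitz_constant(A: np.ndarray, n_iters: int = 100, seed: int = 0) -> float:
    """Estimate L = ||A||_2^2, the Lipschitz constant of x -> A^T (A x - y),
    by power iteration on A^T A (assumes A is an explicit, nonzero matrix)."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(A.shape[1])
    v /= np.linalg.norm(v)
    for _ in range(n_iters):
        w = A.T @ (A @ v)          # apply A^T A
        v = w / np.linalg.norm(w)  # renormalize to avoid overflow
    return float(v @ (A.T @ (A @ v)))  # Rayleigh quotient at the final vector

A = np.diag([1.0, 2.0, 3.0])
L = lipschitz_constant(A)   # close to 9 = 3**2
alpha = 1.0 / L             # a step size satisfying alpha <= 1/L
```

Power iteration only needs products with A and Aᵀ, so the same idea applies when the forward model is available as an operator rather than a stored matrix.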
- forward(y, x0=None, verbose=False)
Runs the FISTA algorithm to solve the optimization problem.
- Parameters:
y (torch.Tensor) – The measurement data from which the image is reconstructed.
x0 (torch.Tensor, optional) – The initial guess for the solution. Defaults to None.
verbose (bool, optional) – Enables verbose output. Defaults to False.
- Returns:
The reconstructed image.
- Return type:
torch.Tensor