# Source code for colibri.recovery.terms.prior

import torch

from .transforms import DCT2D


class Sparsity(torch.nn.Module):
    r"""Sparsity prior.

    .. math::

        g(\mathbf{x}) = \| \transform \textbf{x}\|_1

    where :math:`\transform` is the sparsity basis and :math:`\textbf{x}` is
    the input tensor. Without a basis, :math:`\transform` is the identity.
    """

    def __init__(self, basis=None):
        r"""
        Args:
            basis (str): Basis function, ``'dct'`` or ``None``. Any value other
                than ``'dct'`` is treated as ``None`` (identity). Default is
                ``None``.
        """
        super().__init__()
        # The basis operator must NOT be stored as ``self.transform``: that
        # name belongs to the ``transform`` method below. Assigning ``None``
        # there shadowed the method (so ``forward``/``prox`` crashed calling
        # ``None.forward``); and if DCT2D is an ``nn.Module`` (not visible
        # here), it is registered in ``_modules``, where the class-level
        # method wins the attribute lookup instead.
        # NOTE(review): if DCT2D holds parameters/buffers, their state_dict
        # keys move from ``transform.*`` to ``basis_transform.*``.
        self.basis_transform = DCT2D() if basis == 'dct' else None

    def forward(self, x):
        r"""Compute the sparsity term :math:`\| \transform \textbf{x}\|_1`.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: 0-dim tensor holding the sum of absolute values of
            the transformed input over all of its elements.
        """
        # Plain (not squared) l1 norm, matching the documented g(x) and the
        # soft-thresholding operator implemented by ``prox``.
        return torch.abs(self.transform(x)).sum()

    def prox(self, x, _lambda, type="soft"):
        r"""Compute the proximal operator of the sparsity term.

        The input is mapped into the basis, thresholded coefficient-wise, and
        mapped back with the inverse basis.

        Args:
            x (torch.Tensor): Input tensor.
            _lambda (float): Regularization parameter (threshold).
            type (str): ``"soft"`` (soft thresholding: shrink magnitudes by
                ``_lambda``, clipping at zero) or ``"hard"`` (keep
                coefficients whose magnitude is strictly greater than
                ``_lambda``, zero the rest).

        Returns:
            torch.Tensor: Proximal operator of the sparsity term.

        Raises:
            ValueError: If ``type`` is neither ``"soft"`` nor ``"hard"``.
        """
        if type not in ("soft", "hard"):
            # Reject unknown values instead of silently skipping thresholding.
            raise ValueError(f"type must be 'soft' or 'hard', got {type!r}")
        # NOTE(review): requires_grad_() modifies the caller's tensor in place
        # and fails for integer dtypes; kept so the result stays attached to
        # the autograd graph exactly as before.
        x = x.requires_grad_()
        coeffs = self.transform(x)
        if type == "soft":
            shrunk = torch.clamp(torch.abs(coeffs) - _lambda, min=0)
            coeffs = torch.sign(coeffs) * shrunk
        else:
            coeffs = coeffs * (torch.abs(coeffs) > _lambda)
        return self.inverse(coeffs)

    def transform(self, x):
        r"""Apply the sparsity basis to ``x``; identity when no basis is set."""
        if self.basis_transform is None:
            return x
        return self.basis_transform.forward(x)

    def inverse(self, x):
        r"""Apply the inverse sparsity basis to ``x``; identity when no basis is set."""
        if self.basis_transform is None:
            return x
        return self.basis_transform.inverse(x)