Demo FISTA.

Select Working Directory and Device

import os

from torch.utils import data

import sys

# Step up one level: the parent of the directory the script was launched from
# becomes the working directory for the rest of the run.
os.chdir(os.path.split(os.getcwd())[0])
print("Current Working Directory ", os.getcwd())

# Put the new working directory on the module search path.
sys.path.append(os.getcwd())

# General imports
import matplotlib.pyplot as plt
import torch
import os

# Fix the PyTorch RNG seed so repeated runs draw the same random numbers.
torch.manual_seed(0)

# Set to a falsy value (e.g. None or "") to let the script pick the device
# automatically; any other value is used verbatim as the device.
manual_device = "cpu"

# Report whether CUDA is usable on this machine.
print("GPU support: ", torch.cuda.is_available())

# An explicit choice wins; otherwise prefer the first CUDA device and fall
# back to the CPU when CUDA is unavailable.
device = manual_device or torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Current Working Directory  /home/runner/work/pycolibri/pycolibri
GPU support:  False

Load dataset

from colibri.data.datasets import CustomDataset

# Acquisition system to simulate. It is used later as a key into the model
# lookup table, whose entries are 'spc', 'sd_cassi', 'dd_cassi' and 'c_cassi'.
acquisition_name = "spc"

# Dataset identifier and the root directory handed to CustomDataset.
name, path = "cifar10", "."
batch_size = 1

dataset = CustomDataset(name, path)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar10/cifar-10-python.tar.gz

  0%|          | 0/170498071 [00:00<?, ?it/s]
  0%|          | 851968/170498071 [00:00<00:22, 7589715.86it/s]
  6%|▌         | 10452992/170498071 [00:00<00:02, 56508942.09it/s]
 13%|█▎        | 21987328/170498071 [00:00<00:01, 82378971.80it/s]
 20%|█▉        | 33390592/170498071 [00:00<00:01, 94563624.04it/s]
 26%|██▌       | 44597248/170498071 [00:00<00:01, 100678118.14it/s]
 32%|███▏      | 55345152/170498071 [00:00<00:01, 102882792.50it/s]
 39%|███▉      | 66617344/170498071 [00:00<00:00, 106067498.54it/s]
 46%|████▌     | 77660160/170498071 [00:00<00:00, 107427690.78it/s]
 52%|█████▏    | 89063424/170498071 [00:00<00:00, 109404348.53it/s]
 59%|█████▉    | 100237312/170498071 [00:01<00:00, 110098720.39it/s]
 65%|██████▌   | 111509504/170498071 [00:01<00:00, 110877932.82it/s]
 72%|███████▏  | 122617856/170498071 [00:01<00:00, 110239723.28it/s]
 79%|███████▊  | 134119424/170498071 [00:01<00:00, 111666619.59it/s]
 85%|████████▌ | 145489920/170498071 [00:01<00:00, 112237215.16it/s]
 92%|█████████▏| 156860416/170498071 [00:01<00:00, 112590153.72it/s]
 99%|█████████▊| 168132608/170498071 [00:01<00:00, 111975054.73it/s]
100%|██████████| 170498071/170498071 [00:01<00:00, 103841402.02it/s]
Extracting ./cifar10/cifar-10-python.tar.gz to ./cifar10

Visualize dataset

from torchvision.utils import make_grid

# Take the "input" entry of the first dataset item, prepend a singleton axis
# with unsqueeze(0), and move the result to the selected device.
sample = dataset[0]["input"].unsqueeze(0).to(device)

Optics forward model

import math
from colibri.optics import SPC, SD_CASSI, DD_CASSI, C_CASSI


# Shape of one sample without its leading axis.
img_size = sample.shape[1:]

# Keyword arguments shared by every acquisition model.
acquisition_config = {"input_shape": img_size}

if acquisition_name == "spc":
    # The number of measurements is a perfect square so the measurement
    # vector can later be laid out as a square image for display.
    n_measurements_sqrt = 25
    n_measurements = n_measurements_sqrt**2
    acquisition_config["n_measurements"] = n_measurements

# Look up the optics class by name and instantiate it in one step.
acquisition_model = {
    "spc": SPC,
    "sd_cassi": SD_CASSI,
    "dd_cassi": DD_CASSI,
    "c_cassi": C_CASSI,
}[acquisition_name](**acquisition_config)

# Simulate the measurement of the reference sample.
y = acquisition_model(sample)

# Reconstruct image
from colibri.recovery.fista import Fista
from colibri.recovery.terms.prior import Sparsity
from colibri.recovery.terms.fidelity import L2
from colibri.recovery.terms.transforms import DCT2D


# Solver settings, forwarded to Fista as keyword arguments.
algo_params = dict(max_iters=2000, alpha=1e-4, _lambda=0.01)

# Terms handed to the solver: an L2 fidelity and a sparsity prior configured
# with the "dct" basis.
fidelity = L2()
prior = Sparsity(basis="dct")

fista = Fista(acquisition_model, fidelity, prior, **algo_params)

# Initial estimate: the acquisition model evaluated in "backward" mode on the
# measurement. The solver then refines it.
x0 = acquisition_model.forward(y, type_calculation="backward")
x_hat = fista(y, x0=x0)

# 2D DCT of the reconstruction, detached from the autograd graph; it is only
# used for display.
basis = DCT2D()
theta = basis.forward(x_hat).detach()

def normalize(x):
    """Rescale ``x`` linearly so its minimum maps to 0 and its maximum to 1."""
    return (x - torch.min(x)) / (torch.max(x) - torch.min(x))


def _to_image(tensor):
    """Convert a channels-first tensor to a channels-last NumPy array.

    The tensor is detached and copied to the CPU first, so plotting works
    regardless of the device it lives on or whether it tracks gradients.
    """
    return tensor.detach().cpu().permute(1, 2, 0).numpy()


if acquisition_name == "spc":
    # Lay the SPC measurement vector out as a square
    # n_measurements_sqrt x n_measurements_sqrt image per batch item; the -1
    # axis is inferred (presumably the channel axis -- TODO confirm).
    y = y.reshape(y.shape[0], -1, n_measurements_sqrt, n_measurements_sqrt)

# (title, image) pairs in display order; only the first batch item is shown.
# NOTE(review): abs() after min-max normalization does not change the values;
# normalize(abs(theta)) may have been intended -- confirm before changing.
panels = [
    ("Reference", sample[0]),
    ("Sparse Representation", abs(normalize(theta[0]))),
    ("Measurement", normalize(y[0])),
    ("Reconstruction", normalize(x_hat[0])),
]

plt.figure(figsize=(10, 10))

for index, (title, image) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), index)
    plt.title(title)
    plt.imshow(_to_image(image), cmap="gray")
    # Hide the axis ticks; the panels are images, not graphs.
    plt.xticks([])
    plt.yticks([])

plt.show()
Reference, Sparse Representation, Measurement, Reconstruction

Total running time of the script: (0 minutes 13.199 seconds)

Gallery generated by Sphinx-Gallery