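Demo: FISTA reconstruction.

Recover an image from simulated single-pixel camera (SPC) measurements using FISTA with a DCT sparsity prior.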

Select Working Directory and Device

import os

from torch.utils import data

# NOTE(review): the example is presumably launched from a sub-directory of the
# project (e.g. examples/); move one level up so the project root becomes the
# working directory.
os.chdir(os.path.dirname(os.getcwd()))
print("Current Working Directory ", os.getcwd())

import sys

# Make the (new) working directory importable so the project package can be
# imported below. os.path.join with a single argument was a no-op, so the
# path is appended directly.
sys.path.append(os.getcwd())

# General imports
import matplotlib.pyplot as plt
import torch

# Set random seed for reproducibility
torch.manual_seed(0)

# Device override: a device string such as "cpu" or "cuda:0" forces that
# device; an empty string or None selects CUDA automatically when available.
manual_device = "cpu"
# Check GPU support
print("GPU support: ", torch.cuda.is_available())

# Always produce a torch.device so both branches yield the same type.
if manual_device:
    device = torch.device(manual_device)
else:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Current Working Directory  /home/runner/work/pycolibri/pycolibri
GPU support:  False

Load dataset

from colibri.data.datasets import CustomDataset

# Dataset to load: CustomDataset is constructed from a dataset name and a
# root path.
name = "cifar10"
path = "."
batch_size = 1  # NOTE(review): not used anywhere else in this example


dataset = CustomDataset(name, path)


# Acquisition model to simulate. Valid values are the keys of the acquisition
# dispatch table in the forward-model cell: "spc", "sd_cassi", "dd_cassi",
# "c_cassi".
acquisition_name = "spc"
  0%|          | 0.00/170M [00:00<?, ?B/s]
  1%|          | 1.15M/170M [00:00<00:14, 11.4MB/s]
  2%|▏         | 3.93M/170M [00:00<00:08, 20.8MB/s]
  4%|▍         | 7.67M/170M [00:00<00:05, 28.2MB/s]
  6%|▌         | 10.6M/170M [00:00<00:05, 28.6MB/s]
  8%|▊         | 13.5M/170M [00:00<00:05, 28.0MB/s]
 10%|▉         | 16.3M/170M [00:00<00:05, 27.9MB/s]
 11%|█         | 19.1M/170M [00:00<00:05, 28.0MB/s]
 13%|█▎        | 22.0M/170M [00:00<00:05, 27.7MB/s]
 15%|█▍        | 24.8M/170M [00:00<00:05, 27.8MB/s]
 16%|█▌        | 27.6M/170M [00:01<00:05, 27.7MB/s]
 18%|█▊        | 30.3M/170M [00:01<00:05, 27.3MB/s]
 19%|█▉        | 33.2M/170M [00:01<00:05, 27.5MB/s]
 22%|██▏       | 37.6M/170M [00:01<00:04, 32.3MB/s]
 26%|██▋       | 44.9M/170M [00:01<00:02, 44.5MB/s]
 31%|███▏      | 53.5M/170M [00:01<00:02, 57.0MB/s]
 36%|███▌      | 61.7M/170M [00:01<00:01, 64.3MB/s]
 41%|████      | 70.3M/170M [00:01<00:01, 70.7MB/s]
 47%|████▋     | 79.5M/170M [00:01<00:01, 77.0MB/s]
 52%|█████▏    | 88.4M/170M [00:01<00:01, 80.8MB/s]
 57%|█████▋    | 97.6M/170M [00:02<00:00, 83.9MB/s]
 63%|██████▎   | 107M/170M [00:02<00:00, 85.7MB/s]
 68%|██████▊   | 116M/170M [00:02<00:00, 87.4MB/s]
 73%|███████▎  | 125M/170M [00:02<00:00, 85.3MB/s]
 78%|███████▊  | 133M/170M [00:02<00:00, 86.4MB/s]
 84%|████████▎ | 143M/170M [00:02<00:00, 87.7MB/s]
 89%|████████▉ | 151M/170M [00:02<00:00, 87.8MB/s]
 94%|█████████▍| 161M/170M [00:02<00:00, 89.0MB/s]
 99%|█████████▉| 170M/170M [00:02<00:00, 89.4MB/s]
100%|██████████| 170M/170M [00:02<00:00, 59.9MB/s]

Select a sample image

from torchvision.utils import make_grid  # NOTE(review): currently unused here

# Take the "input" entry of the first dataset item, prepend a batch axis of
# size 1, and place it on the selected device.
sample = dataset[0]["input"].unsqueeze(0).to(device)

Optics forward model

import math
from colibri.optics import SPC, SD_CASSI, DD_CASSI, C_CASSI


# `sample` carries a leading batch axis of size 1; the acquisition model is
# configured with the shape of a single image.
img_size = sample.shape[1:]

acquisition_config = dict(
    input_shape=img_size,
)

if acquisition_name == "spc":
    # A square number of measurements lets the SPC measurements be displayed
    # as an n_measurements_sqrt x n_measurements_sqrt image further down.
    n_measurements = 25**2
    n_measurements_sqrt = math.isqrt(n_measurements)
    acquisition_config["n_measurements"] = n_measurements

acquisition_models = {"spc": SPC, "sd_cassi": SD_CASSI, "dd_cassi": DD_CASSI, "c_cassi": C_CASSI}
if acquisition_name not in acquisition_models:
    raise ValueError(
        f"Unknown acquisition_name {acquisition_name!r}; "
        f"expected one of {sorted(acquisition_models)}"
    )

acquisition_model = acquisition_models[acquisition_name](**acquisition_config)

# Simulate the measurements of the reference image.
y = acquisition_model(sample)

# Reconstruct image
from colibri.recovery.fista import Fista
from colibri.recovery.terms.prior import Sparsity
from colibri.recovery.terms.fidelity import L2
from colibri.recovery.terms.transforms import DCT2D


# FISTA settings: iteration budget, "alpha" (presumably the gradient step
# size) and "_lambda" (presumably the weight of the sparsity prior) -- confirm
# against colibri.recovery.fista.Fista.
algo_params = {
    "max_iters": 2000,
    "alpha": 1e-4,
    "_lambda": 0.01,
}

fidelity = L2()
prior = Sparsity(basis="dct")

fista = Fista(acquisition_model, fidelity, prior, **algo_params)

# Start FISTA from the back-projection of the measurements.
x0 = acquisition_model.forward(y, type_calculation="backward")
x_hat = fista(y, x0=x0)

# Coefficients of the reconstruction in the 2-D DCT basis, for display only.
basis = DCT2D()

theta = basis.forward(x_hat).detach()


def normalize(x):
    """Min-max scale a tensor to [0, 1]; a constant tensor maps to zeros."""
    x_min, x_max = torch.min(x), torch.max(x)
    span = x_max - x_min
    if span == 0:
        # Avoid 0/0 -> NaN on constant inputs.
        return torch.zeros_like(x)
    return (x - x_min) / span


def show_panel(position, title, image):
    """Draw a channel-first (C, H, W) tensor as panel `position` of a 1x4 grid.

    The tensor is detached and copied to the CPU first so matplotlib can
    convert it whatever the selected device or autograd state.
    """
    plt.subplot(1, 4, position)
    plt.title(title)
    plt.imshow(image.detach().cpu().permute(1, 2, 0).numpy(), cmap="gray")
    plt.xticks([])
    plt.yticks([])


plt.figure(figsize=(10, 10))

show_panel(1, "Reference", sample[0])

# Magnitude of the DCT coefficients, scaled to [0, 1]; taking abs() before
# scaling is what makes large negative and large positive coefficients look
# alike (abs() after min-max scaling would be a no-op).
show_panel(2, "Sparse Representation", normalize(theta[0].abs()))

if acquisition_name == "spc":
    # Lay the SPC measurements out as a square image for display.
    y = y.reshape(y.shape[0], -1, n_measurements_sqrt, n_measurements_sqrt)

show_panel(3, "Measurement", normalize(y[0]))

show_panel(4, "Reconstruction", normalize(x_hat[0]))

plt.show()
Reference, Sparse Representation, Measurement, Reconstruction

Total running time of the script: (0 minutes 14.597 seconds)

Gallery generated by Sphinx-Gallery