
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/demo_fista.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_demo_fista.py:


Demo FISTA.
===================================================

.. GENERATED FROM PYTHON SOURCE LINES 8-10

Select Working Directory and Device
-----------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 10-38

.. code-block:: Python

    import os
    from torch.utils import data

    os.chdir(os.path.dirname(os.getcwd()))
    print("Current Working Directory ", os.getcwd())

    import sys

    sys.path.append(os.path.join(os.getcwd()))

    # General imports
    import matplotlib.pyplot as plt
    import torch
    import os

    # Set random seed for reproducibility
    torch.manual_seed(0)

    manual_device = "cpu"
    # Check GPU support
    print("GPU support: ", torch.cuda.is_available())

    if manual_device:
        device = manual_device
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Current Working Directory /home/runner/work/pycolibri/pycolibri
    GPU support: False


.. GENERATED FROM PYTHON SOURCE LINES 39-41

Load dataset
-----------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 41-53

.. code-block:: Python

    from colibri.data.datasets import CustomDataset

    name = "cifar10"
    path = "."
    batch_size = 1

    dataset = CustomDataset(name, path)

    acquisition_name = 'spc'  # one of: 'spc', 'sd_cassi', 'dd_cassi', 'c_cassi'


.. GENERATED FROM PYTHON SOURCE LINES 54-56

Visualize dataset
-----------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 56-61

.. code-block:: Python

    from torchvision.utils import make_grid

    sample = dataset[0]["input"]
    sample = sample.unsqueeze(0).to(device)

.. GENERATED FROM PYTHON SOURCE LINES 62-63

Optics forward model

.. GENERATED FROM PYTHON SOURCE LINES 63-145

.. code-block:: Python

    import math
    from colibri.optics import SPC, SD_CASSI, DD_CASSI, C_CASSI

    img_size = sample.shape[1:]

    acquisition_config = dict(
        input_shape=img_size,
    )

    if acquisition_name == "spc":
        n_measurements = 25**2
        n_measurements_sqrt = int(math.sqrt(n_measurements))
        acquisition_config["n_measurements"] = n_measurements

    acquisition_model = {"spc": SPC, "sd_cassi": SD_CASSI, "dd_cassi": DD_CASSI, "c_cassi": C_CASSI}[
        acquisition_name
    ]

    acquisition_model = acquisition_model(**acquisition_config)

    y = acquisition_model(sample)

    # Reconstruct image
    from colibri.recovery.fista import Fista
    from colibri.recovery.terms.prior import Sparsity
    from colibri.recovery.terms.fidelity import L2
    from colibri.recovery.terms.transforms import DCT2D

    algo_params = {
        "max_iters": 2000,
        "alpha": 1e-4,
        "_lambda": 0.01,
    }

    fidelity = L2()
    prior = Sparsity(basis="dct")

    fista = Fista(acquisition_model, fidelity, prior, **algo_params)

    x0 = acquisition_model.forward(y, type_calculation="backward")
    x_hat = fista(y, x0=x0)

    basis = DCT2D()
    theta = basis.forward(x_hat).detach()

    normalize = lambda x: (x - torch.min(x)) / (torch.max(x) - torch.min(x))

    plt.figure(figsize=(10, 10))

    plt.subplot(1, 4, 1)
    plt.title("Reference")
    plt.imshow(sample[0, :, :].permute(1, 2, 0), cmap="gray")
    plt.xticks([])
    plt.yticks([])

    plt.subplot(1, 4, 2)
    plt.title("Sparse Representation")
    plt.imshow(abs(normalize(theta[0, :, :])).permute(1, 2, 0), cmap="gray")
    plt.xticks([])
    plt.yticks([])

    if acquisition_name == "spc":
        y = y.reshape(y.shape[0], -1, n_measurements_sqrt, n_measurements_sqrt)

    plt.subplot(1, 4, 3)
    plt.title("Measurement")
    plt.imshow(normalize(y[0, :, :]).permute(1, 2, 0), cmap="gray")
    plt.xticks([])
    plt.yticks([])

    plt.subplot(1, 4, 4)
    plt.title("Reconstruction")
    plt.imshow(normalize(x_hat[0, :, :]).permute(1, 2, 0).detach().cpu().numpy(), cmap="gray")
    plt.xticks([])
    plt.yticks([])

    plt.show()

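
As a side note on the solver used in this example, the sketch below shows the
generic FISTA iteration on a small synthetic problem, written in plain NumPy.
It is not the ``colibri`` implementation: it assumes a dense matrix ``A`` as
the forward model, a least-squares fidelity term, and sparsity directly in the
signal domain rather than in a DCT basis. The names ``fista_l1``,
``soft_threshold``, ``step`` and ``lam`` are hypothetical and chosen for this
illustration only; how they relate to ``alpha`` and ``_lambda`` above is an
assumption.

.. code-block:: Python

    import numpy as np


    def soft_threshold(v, tau):
        """Proximal operator of tau * ||.||_1 (element-wise shrinkage)."""
        return np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)


    def fista_l1(A, y, lam, step, max_iters=500):
        """Approximately minimise 0.5 * ||A x - y||^2 + lam * ||x||_1."""
        x = np.zeros(A.shape[1])
        z = x.copy()  # extrapolated point where the gradient is evaluated
        t = 1.0       # momentum parameter
        for _ in range(max_iters):
            grad = A.T @ (A @ z - y)  # gradient of the fidelity term at z
            x_next = soft_threshold(z - step * grad, step * lam)
            t_next = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
            z = x_next + ((t - 1.0) / t_next) * (x_next - x)
            x, t = x_next, t_next
        return x


    # Synthetic compressive measurement of a 3-sparse signal (illustrative data)
    rng = np.random.default_rng(0)
    A = rng.standard_normal((40, 100)) / np.sqrt(40)
    x_true = np.zeros(100)
    x_true[[5, 37, 80]] = [1.0, -2.0, 1.5]
    y = A @ x_true

    # Step size 1 / L, with L the largest eigenvalue of A^T A
    step = 1.0 / np.linalg.norm(A, 2) ** 2
    x_hat = fista_l1(A, y, lam=0.01, step=step)
    print(np.round(x_hat[[5, 37, 80]], 2))

Each iteration takes a gradient step on the fidelity term, applies the
proximal operator of the prior, and then extrapolates using the two most
recent iterates. The extrapolation is what distinguishes FISTA from plain
ISTA.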

.. image-sg:: /auto_examples/images/sphx_glr_demo_fista_001.png
   :alt: Reference, Sparse Representation, Measurement, Reconstruction
   :srcset: /auto_examples/images/sphx_glr_demo_fista_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 9.853 seconds)


.. _sphx_glr_download_auto_examples_demo_fista.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: demo_fista.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: demo_fista.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: demo_fista.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_