.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/demo_fast_sdcassi.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_demo_fast_sdcassi.py:


Demo Fast SD CASSI
===================================================

.. GENERATED FROM PYTHON SOURCE LINES 8-10

Select Working Directory and Device
-----------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 10-42

.. code-block:: Python


    import os

    # Move up one level to the repository root
    os.chdir(os.path.dirname(os.getcwd()))
    print("Current Working Directory", os.getcwd())

    import sys

    # Make the local colibri package importable
    sys.path.append(os.getcwd())

    # General imports
    import numpy as np
    import torch
    from matplotlib import pyplot as plt
    from skimage.transform import resize

    from colibri.data.datasets import CustomDataset
    from colibri.optics import SD_CASSI
    from colibri.recovery.solvers.sd_cassi import L2L2SolverSDCASSI

    # Set random seed for reproducibility
    torch.manual_seed(0)

    manual_device = "cpu"

    # Check GPU support
    print("GPU support:", torch.cuda.is_available())

    if manual_device:
        device = manual_device
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Current Working Directory /home/runner/work/pycolibri/pycolibri
    GPU support: False


.. GENERATED FROM PYTHON SOURCE LINES 43-45

Parameters
-----------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 45-50

.. code-block:: Python


    M = 256  # image height in pixels
    N = M  # image width in pixels
    L = 3  # number of spectral bands kept
    sample_index = 1  # index of the CAVE sample to load


.. GENERATED FROM PYTHON SOURCE LINES 51-53

Data
-----------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 53-64

.. code-block:: Python


    dataset = CustomDataset('cave', path='data')

    sample = dataset[sample_index]
    x = sample['output'].numpy()
    x = x[::-1, ...]  # reverse the order of the spectral bands
    # Subsample the bands with stride int(bands / L) + 1, resize to (L, M, N)
    # and add a batch axis
    x = torch.from_numpy(resize(x[::int(x.shape[0] / L) + 1], [L, M, N]))[None, ...].float()
    x = x / x.max()  # Normalize the image to [0, 1]

    print(x.shape)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Dataset already downloaded
    torch.Size([1, 3, 256, 256])


.. GENERATED FROM PYTHON SOURCE LINES 65-67

SD CASSI Forward Model
-----------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 67-75

.. code-block:: Python


    sd_cassi = SD_CASSI((L, M, N))  # SD CASSI sensing model for an (L, M, N) cube
    y = sd_cassi(x)  # simulated compressed measurement

    ca = sd_cassi.learnable_optics[0]  # coded aperture


.. GENERATED FROM PYTHON SOURCE LINES 76-78

Set up and run the reconstruction
-----------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 78-156

.. code-block:: Python


    recons_fn = L2L2SolverSDCASSI(y, sd_cassi)

    xtilde = torch.zeros_like(x, device=x.device)  # initial guess
    rho = 1e-4  # regularization parameter

    # Solution with the solver
    recons_img = recons_fn.solve(xtilde, rho)

    # Solution with the manual inversion: apply Qinv to the backward projection of y
    [Qinv, rhoIQ] = recons_fn.ComputeQinv(rho)
    recons_img2 = recons_fn.IMVMS(sd_cassi(y, type_calculation="backward"), Q=Qinv)

    error = torch.norm(recons_img - recons_img2)  # L2 norm of the difference between both solutions

    fig, ax = plt.subplots(2, 2, figsize=(6, 6))

    plt.suptitle(f'Inversion error norm2(recon_img1 - recon_img2): {error:.4e}', fontsize=12)

    ax[0, 0].imshow(x[0].permute(1, 2, 0) / x.max())
    ax[0, 0].set_title("Original Image")
    ax[0, 0].axis("off")

    ax[0, 1].imshow(y[0].permute(1, 2, 0))
    ax[0, 1].set_title("Measurement")
    ax[0, 1].axis("off")

    ax[1, 0].imshow(recons_img[0].permute(1, 2, 0) / recons_img.max())
    ax[1, 0].set_title("Reconstructed Image 1")
    ax[1, 0].axis("off")

    ax[1, 1].imshow(recons_img2[0].permute(1, 2, 0) / recons_img2.max())
    ax[1, 1].set_title("Reconstructed Image 2")
    ax[1, 1].axis("off")

    plt.tight_layout()
    plt.show()

    # Refine the fast inversion with PnP-ADMM (L2 fidelity, DCT sparsity prior)
    from colibri.recovery.terms.fidelity import L2
    from colibri.recovery.terms.prior import Sparsity
    from colibri.recovery.pnp import PnP_ADMM

    fidelity = L2()
    prior = Sparsity(basis="dct")

    algo_params = {
        'max_iters': 10,
        '_lambda': 0.1,
        'rho': 0.01,
        'alpha': 1e-4,
    }

    pnp = PnP_ADMM(sd_cassi, fidelity, prior, **algo_params)

    x0 = recons_img  # warm start from the fast inversion
    x_hat = pnp(y, x0=x0, verbose=False).detach()

    fig, ax = plt.subplots(1, 4, figsize=(20, 5))

    ax[0].set_title("Measurement")
    ax[0].imshow(y[0].permute(1, 2, 0))
    ax[0].axis("off")

    ax[1].set_title("Fast Inversion")
    ax[1].imshow(recons_img[0].permute(1, 2, 0) / recons_img.max())
    ax[1].axis("off")

    ax[2].set_title("PnP ADMM")
    ax[2].imshow(x_hat[0].permute(1, 2, 0) / x_hat.max())
    ax[2].axis("off")

    ax[3].set_title("Reference")
    ax[3].imshow(x[0].permute(1, 2, 0) / x.max())
    ax[3].axis("off")

    plt.show()


.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /auto_examples/images/sphx_glr_demo_fast_sdcassi_001.png
          :alt: Inversion error norm2(recon_img1 - recon_img2): 6.7183e-02, Original Image, Measurement, Reconstructed Image 1, Reconstructed Image 2
          :srcset: /auto_examples/images/sphx_glr_demo_fast_sdcassi_001.png
          :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_demo_fast_sdcassi_002.png
          :alt: Measurement, Fast Inversion, PnP ADMM, Reference
          :srcset: /auto_examples/images/sphx_glr_demo_fast_sdcassi_002.png
          :class: sphx-glr-multi-img


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.014348558..1.0].
    Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.12033227..1.0].


.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 0.981 seconds)
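
The example calls ``L2L2SolverSDCASSI`` without showing how its closed-form solve is computed. The sketch below is a self-contained illustration of why such a solve can be fast for SD CASSI. It is not colibri's implementation, and it rests on two assumptions.

First, it uses a toy forward model: each band of an ``(L, M, N)`` cube is multiplied by a binary coded aperture, shifted by its band index, and summed on an ``M x (N + L - 1)`` detector. Second, it assumes the target problem is :math:`\min_x \|y - Hx\|_2^2 + \rho \|x - \tilde{x}\|_2^2`.

Under the first assumption every voxel reaches exactly one detector pixel, so :math:`HH^T` is diagonal. The Woodbury identity then reduces the inverse to element-wise operations. The helper names ``forward``, ``adjoint``, ``fast_tikhonov`` and ``dense_matrix`` are hypothetical.

.. code-block:: Python

    import numpy as np

    # Toy SD CASSI model (an assumption, not colibri's exact convention): band l
    # of an (L, M, N) cube is multiplied by the coded aperture `mask` (M, N),
    # shifted l columns and summed on an M x (N + L - 1) detector.


    def forward(x, mask):
        """Apply H: (L, M, N) cube -> (M, N + L - 1) measurement."""
        L, M, N = x.shape
        y = np.zeros((M, N + L - 1))
        for l in range(L):
            y[:, l:l + N] += mask * x[l]
        return y


    def adjoint(y, mask, L):
        """Apply H^T: (M, N + L - 1) measurement -> (L, M, N) cube."""
        N = mask.shape[1]
        return np.stack([mask * y[:, l:l + N] for l in range(L)])


    def fast_tikhonov(y, mask, L, rho, x_tilde=None):
        """Minimize ||y - Hx||^2 + rho * ||x - x_tilde||^2 in closed form.

        Uses the Woodbury identity
            (H^T H + rho I)^-1 = (I - H^T (rho I + H H^T)^-1 H) / rho,
        which only needs element-wise divisions because H H^T is diagonal
        (each voxel lands on exactly one detector pixel).
        """
        M, N = mask.shape
        if x_tilde is None:
            x_tilde = np.zeros((L, M, N))
        b = adjoint(y, mask, L) + rho * x_tilde
        d = forward(np.ones((L, M, N)), mask ** 2)  # diagonal of H H^T
        return (b - adjoint(forward(b, mask) / (rho + d), mask, L)) / rho


    def dense_matrix(mask, L):
        """Build H column by column (only practical for tiny cubes)."""
        M, N = mask.shape
        eye = np.eye(L * M * N)
        return np.stack(
            [forward(e.reshape(L, M, N), mask).ravel() for e in eye], axis=1
        )


    rng = np.random.default_rng(0)
    L, M, N = 3, 8, 8
    rho = 1e-4  # same value as in the example above
    mask = rng.integers(0, 2, size=(M, N)).astype(float)  # random binary aperture
    x = rng.random((L, M, N))
    y = forward(x, mask)

    x_fast = fast_tikhonov(y, mask, L, rho)

    # Reference: solve (H^T H + rho I) x = H^T y with an explicit dense H
    H = dense_matrix(mask, L)
    A = H.T @ H + rho * np.eye(H.shape[1])
    x_dense = np.linalg.solve(A, H.T @ y.ravel()).reshape(L, M, N)

    print("max |x_fast - x_dense| =", np.abs(x_fast - x_dense).max())

The dense reference only serves to check the result on a tiny cube. The matrix-free version never forms :math:`H`, so its cost grows linearly with the number of voxels.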