.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/demo_does.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_demo_does.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_demo_does.py:

Demo DOEs
===================================================

In this example we show how to simulate an imaging system based on a
diffractive optical element (DOE): we build the DOE, compute its point spread
function (PSF), and image a dataset through it, first with the functional API
and then with the ``SingleDOESpectral`` acquisition model.

.. GENERATED FROM PYTHON SOURCE LINES 11-13

Select Working Directory and Device
-----------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 13-58

.. code-block:: Python

    import os

    from torch.utils.data import DataLoader

    os.chdir(os.path.dirname(os.getcwd()))
    print("Current Working Directory ", os.getcwd())

    import numpy as np

    from colibri.optics.functional import (
        psf_single_doe_spectral,
        convolutional_sensing,
        ideal_panchromatic_sensor,
    )
    from colibri.optics.sota_does import (
        spiral_doe,
        conventional_lens,
        spiral_refractive_index,
        nbk7_refractive_index,
    )
    from colibri.optics import SingleDOESpectral
    from colibri import seed_everything

    # General imports
    import matplotlib.pyplot as plt
    import torch

    seed_everything()

    manual_device = "cpu"  # set to "" (falsy) to auto-select the device

    doe_size = (100, 100)
    img_size = (200, 200)

    type_doe = "spiral"  # spiral, conventional_lens
    convolution_domain = "fourier"  # signal, fourier
    type_wave_propagation = "angular_spectrum"  # fresnel, angular_spectrum, fraunhofer

    wavelengths = torch.Tensor([450, 550, 650]) * 1e-9

    # Check GPU support
    print("GPU support: ", torch.cuda.is_available())

    if manual_device:
        device = manual_device
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Current Working Directory  /home/runner/work/pycolibri/pycolibri
    GPU support:  False
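A note on the device override above: the check ``if manual_device:`` relies on
Python truthiness, so any non-empty string such as ``"cpu"`` wins over
auto-selection, while an empty string falls through to CUDA detection. A
minimal standalone sketch of that pattern (the helper name ``select_device``
and the ``cuda_available`` flag are illustrative, not part of colibri):

```python
def select_device(manual_device, cuda_available):
    """Return the manually requested device, or auto-select one."""
    if manual_device:  # any non-empty string overrides auto-selection
        return manual_device
    return "cuda:0" if cuda_available else "cpu"


print(select_device("cpu", True))   # manual override wins -> cpu
print(select_device("", True))      # auto-selects CUDA -> cuda:0
print(select_device("", False))     # auto-selects CPU -> cpu
```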
.. GENERATED FROM PYTHON SOURCE LINES 59-61

Load dataset
-----------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 61-73

.. code-block:: Python

    from colibri.data.datasets import CustomDataset

    name = "cifar10"
    path = "."
    batch_size = 16

    dataset = CustomDataset(name, path)

    dataset_loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=0
    )

.. GENERATED FROM PYTHON SOURCE LINES 74-76

Visualize dataset
-----------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 76-88

.. code-block:: Python

    from torchvision.utils import make_grid

    sample = next(iter(dataset_loader))["input"]
    img = make_grid(
        sample[:32], nrow=8, padding=1, normalize=True, scale_each=False, pad_value=0
    )

    plt.figure(figsize=(10, 10))
    plt.imshow(img.permute(1, 2, 0))
    plt.title("CIFAR10 dataset")
    plt.axis("off")
    plt.show()

.. image-sg:: /auto_examples/images/sphx_glr_demo_does_001.png
   :alt: CIFAR10 dataset
   :srcset: /auto_examples/images/sphx_glr_demo_does_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 89-93

Optics forward model
-----------------------------------------------

Define the forward operator :math:`\mathbf{y} = \mathbf{H}_\phi \mathbf{x}`,
in this case, the single-DOE spectral imaging model. Each optics model can
compute the forward and backward operators, i.e.,
:math:`\mathbf{y} = \mathbf{H}_\phi \mathbf{x}` and
:math:`\mathbf{x} = \mathbf{H}^T_\phi \mathbf{y}`.

.. GENERATED FROM PYTHON SOURCE LINES 93-137
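To make the operator notation concrete, here is a toy linear model where
``H`` is an explicit matrix: the forward pass computes ``y = H x`` and the
backward (adjoint) pass computes ``H^T y``. This is a plain-Python sketch of
the abstraction, not colibri's implementation:

```python
# A 2x3 forward operator: two measurements from a three-element scene.
H = [[1.0, 0.0, 1.0],
     [0.0, 1.0, 0.0]]


def forward(H, x):
    """y = H x (matrix-vector product)."""
    return [sum(h * xj for h, xj in zip(row, x)) for row in H]


def adjoint(H, y):
    """x_bt = H^T y (transpose applied to the measurements)."""
    n_cols = len(H[0])
    return [sum(H[i][j] * y[i] for i in range(len(H))) for j in range(n_cols)]


x = [1.0, 2.0, 3.0]
y = forward(H, x)        # [4.0, 2.0]
x_bt = adjoint(H, y)     # [4.0, 2.0, 4.0]
```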
.. code-block:: Python

    if type_wave_propagation == "fraunhofer":
        source_distance = np.inf
    else:
        source_distance = 1  # meters

    if type_doe == "spiral":
        radius_doe = 0.5e-3
        sensor_distance = 50e-3
        pixel_size = (2 * radius_doe) / np.min(doe_size)
        height_map, aperture = spiral_doe(
            M=doe_size[0],
            N=doe_size[1],
            number_spirals=3,
            radius=radius_doe,
            focal=50e-3,
            start_w=450e-9,
            end_w=650e-9,
        )
        refractive_index = spiral_refractive_index
    else:
        radius_doe = 0.5e-3
        focal = 50e-3
        if source_distance == np.inf:
            sensor_distance = focal
        else:
            # Thin-lens equation: 1/focal = 1/source_distance + 1/sensor_distance
            sensor_distance = 1 / (1 / focal - 1 / source_distance)
        pixel_size = (2 * radius_doe) / np.min(doe_size)
        height_map, aperture = conventional_lens(
            M=doe_size[0], N=doe_size[1], focal=focal, radius=radius_doe
        )
        refractive_index = nbk7_refractive_index

    fig, ax = plt.subplots(1, 2, figsize=(10, 10))
    ax[0].imshow(height_map, cmap="viridis")
    ax[0].set_title("Height map")
    ax[1].imshow(aperture, cmap="plasma")
    ax[1].set_title("Aperture map")
    plt.show()

.. image-sg:: /auto_examples/images/sphx_glr_demo_does_002.png
   :alt: Height map, Aperture map
   :srcset: /auto_examples/images/sphx_glr_demo_does_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 138-144

Functional API for DOEs
-----------------------------------------------

Simulate the acquisition with the functional API:
:meth:`psf_single_doe_spectral` computes the PSF of the DOE from its height
map, :meth:`convolutional_sensing` images the scene by convolving it with the
PSF, and :meth:`ideal_panchromatic_sensor` integrates the spectral channels
into a single measurement.

.. GENERATED FROM PYTHON SOURCE LINES 144-180
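The conventional-lens branch above places the sensor at the image distance
given by the thin-lens equation, :math:`1/f = 1/s_o + 1/s_i`; for a source at
infinity the sensor sits at the focal plane. A small standalone sketch of
that relation (the helper name is illustrative):

```python
def in_focus_sensor_distance(focal, source_distance):
    """Image distance s_i from the thin-lens equation 1/f = 1/s_o + 1/s_i."""
    if source_distance == float("inf"):
        return focal  # parallel rays focus at the focal plane
    return 1.0 / (1.0 / focal - 1.0 / source_distance)


print(in_focus_sensor_distance(50e-3, float("inf")))  # 0.05
print(in_focus_sensor_distance(50e-3, 1.0))           # ~0.0526, slightly beyond f
```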
.. code-block:: Python

    psf = psf_single_doe_spectral(
        height_map=height_map,
        aperture=aperture,
        refractive_index=refractive_index,
        wavelengths=wavelengths,
        source_distance=source_distance,
        sensor_distance=sensor_distance,
        pixel_size=pixel_size,
        approximation=type_wave_propagation,
    )

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(((psf - psf.min()) / (psf.max() - psf.min())).permute(1, 2, 0), cmap="plasma")

    image = convolutional_sensing(sample, psf, domain=convolution_domain)
    img = make_grid(image, nrow=4, padding=1, normalize=True, scale_each=False, pad_value=0)

    plt.figure(figsize=(10, 10))
    plt.imshow(img.permute(1, 2, 0))
    plt.title("CIFAR10 imaged")
    plt.axis("off")
    plt.show()

    image = ideal_panchromatic_sensor(image)
    img = make_grid(image, nrow=4, padding=1, normalize=True, scale_each=False, pad_value=0)

    plt.figure(figsize=(10, 10))
    plt.imshow(img.permute(1, 2, 0))
    plt.title("CIFAR10 imaged panchromatic sensor")
    plt.axis("off")
    plt.show()

.. rst-class:: sphx-glr-horizontal

    *

      .. image-sg:: /auto_examples/images/sphx_glr_demo_does_003.png
         :alt: demo does
         :srcset: /auto_examples/images/sphx_glr_demo_does_003.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_demo_does_004.png
         :alt: CIFAR10 imaged
         :srcset: /auto_examples/images/sphx_glr_demo_does_004.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_demo_does_005.png
         :alt: CIFAR10 imaged panchromatic sensor
         :srcset: /auto_examples/images/sphx_glr_demo_does_005.png
         :class: sphx-glr-multi-img

.. GENERATED FROM PYTHON SOURCE LINES 181-184

DOE Class
-----------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 184-225
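The functional pipeline above ends with ``ideal_panchromatic_sensor``, which
collapses the spectral channels into a single measurement. As a rough
intuition only (assumption: this sketch averages the channels with a flat
spectral response; colibri's actual implementation may weight them
differently), using a nested-list image in channel-first layout:

```python
def toy_panchromatic(image_chw):
    """Average a channel-first (C, H, W) nested-list image over its channels."""
    channels = len(image_chw)
    rows, cols = len(image_chw[0]), len(image_chw[0][0])
    return [
        [sum(image_chw[c][i][j] for c in range(channels)) / channels
         for j in range(cols)]
        for i in range(rows)
    ]


# Three 1x2 spectral channels collapse to one 1x2 panchromatic image.
print(toy_panchromatic([[[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]]]))  # [[3.0, 4.0]]
```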
.. code-block:: Python

    acquisition_model = SingleDOESpectral(
        input_shape=sample.shape[1:],
        height_map=height_map,
        aperture=aperture,
        wavelengths=wavelengths,
        source_distance=source_distance,
        sensor_distance=sensor_distance,
        sensor_spectral_sensitivity=lambda x: x,
        pixel_size=pixel_size,
        doe_refractive_index=refractive_index,
        approximation=type_wave_propagation,
        domain=convolution_domain,
        trainable=False,
    )

    psf = acquisition_model.get_psf()

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(((psf - psf.min()) / (psf.max() - psf.min())).permute(1, 2, 0), cmap="plasma")
    ax.set_title("RGB PSF")
    plt.show()

    image = acquisition_model(sample)
    img = make_grid(image, nrow=4, padding=1, normalize=True, scale_each=False, pad_value=0)

    plt.figure(figsize=(10, 10))
    plt.imshow(img.permute(1, 2, 0))
    plt.title("CIFAR10 imaged")
    plt.axis("off")
    plt.show()

    image = ideal_panchromatic_sensor(image)
    img = make_grid(image, nrow=4, padding=1, normalize=True, scale_each=False, pad_value=0)

    plt.figure(figsize=(10, 10))
    plt.imshow(img.permute(1, 2, 0))
    plt.title("CIFAR10 imaged panchromatic sensor")
    plt.axis("off")
    plt.show()

.. rst-class:: sphx-glr-horizontal

    *

      .. image-sg:: /auto_examples/images/sphx_glr_demo_does_006.png
         :alt: RGB PSF
         :srcset: /auto_examples/images/sphx_glr_demo_does_006.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_demo_does_007.png
         :alt: CIFAR10 imaged
         :srcset: /auto_examples/images/sphx_glr_demo_does_007.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_demo_does_008.png
         :alt: CIFAR10 imaged panchromatic sensor
         :srcset: /auto_examples/images/sphx_glr_demo_does_008.png
         :class: sphx-glr-multi-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.560 seconds)

.. _sphx_glr_download_auto_examples_demo_does.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example
    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: demo_does.ipynb <demo_does.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: demo_does.py <demo_does.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: demo_does.zip <demo_does.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_