Demo Filtered Spectral Initialization

Select Working Directory and Device

import os

from torch.utils import data

import sys

# Step one directory up from where the script was launched. The captured
# output of this cell shows the result is the repository checkout root.
launch_dir = os.getcwd()
os.chdir(os.path.dirname(launch_dir))
repo_root = os.getcwd()
print("Current Working Directory ", repo_root)

# Put the new working directory on the import path, presumably so the local
# ``colibri`` package imported further down can be resolved.
sys.path.append(repo_root)

# General imports
import matplotlib.pyplot as plt
import torch
import os

# Seed torch's global RNG so repeated runs produce the same numbers.
torch.manual_seed(0)

# Hard-coded device override. Set it to a falsy value (e.g. "" or None) to
# fall back to the first CUDA device when one is available, else the CPU.
manual_device = "cpu"
print("GPU support: ", torch.cuda.is_available())

device = manual_device or torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Current Working Directory  /home/runner/work/pycolibri/pycolibri
GPU support:  False

Load dataset

from colibri.data.datasets import CustomDataset
import torchvision
# Dataset identifier understood by ``CustomDataset`` and the root folder for it.
name, path = "cifar10", "."

# NOTE(review): ``batch_size`` is not read anywhere later in this script.
batch_size = 1

# Build the dataset. The progress bar in this cell's captured output shows the
# archive being downloaded on this run.
dataset = CustomDataset(name, path)
  0%|          | 0.00/170M [00:00<?, ?B/s]
  0%|          | 32.8k/170M [00:00<11:04, 257kB/s]
  0%|          | 229k/170M [00:00<02:49, 1.00MB/s]
  1%|          | 885k/170M [00:00<00:54, 3.11MB/s]
  2%|▏         | 3.44M/170M [00:00<00:15, 11.1MB/s]
  4%|▍         | 6.82M/170M [00:00<00:09, 17.6MB/s]
  8%|▊         | 13.4M/170M [00:00<00:05, 29.1MB/s]
 12%|█▏        | 19.7M/170M [00:00<00:04, 35.3MB/s]
 15%|█▌        | 26.1M/170M [00:00<00:03, 39.1MB/s]
 19%|█▉        | 32.4M/170M [00:01<00:03, 41.6MB/s]
 23%|██▎       | 38.8M/170M [00:01<00:03, 43.1MB/s]
 27%|██▋       | 45.5M/170M [00:01<00:02, 44.6MB/s]
 30%|███       | 51.9M/170M [00:01<00:02, 45.3MB/s]
 34%|███▍      | 58.2M/170M [00:01<00:02, 45.7MB/s]
 38%|███▊      | 64.5M/170M [00:01<00:02, 46.0MB/s]
 42%|████▏     | 71.0M/170M [00:01<00:02, 46.4MB/s]
 45%|████▌     | 77.3M/170M [00:02<00:02, 46.4MB/s]
 49%|████▉     | 84.0M/170M [00:02<00:01, 46.9MB/s]
 53%|█████▎    | 90.3M/170M [00:02<00:01, 46.9MB/s]
 57%|█████▋    | 96.7M/170M [00:02<00:01, 46.8MB/s]
 60%|██████    | 103M/170M [00:02<00:01, 46.7MB/s]
 64%|██████▍   | 109M/170M [00:02<00:01, 46.9MB/s]
 68%|██████▊   | 116M/170M [00:02<00:01, 46.7MB/s]
 72%|███████▏  | 122M/170M [00:03<00:01, 46.7MB/s]
 75%|███████▌  | 128M/170M [00:03<00:00, 46.7MB/s]
 79%|███████▉  | 135M/170M [00:03<00:00, 46.8MB/s]
 83%|████████▎ | 141M/170M [00:03<00:00, 46.9MB/s]
 87%|████████▋ | 148M/170M [00:03<00:00, 46.9MB/s]
 90%|█████████ | 154M/170M [00:03<00:00, 46.9MB/s]
 94%|█████████▍| 161M/170M [00:03<00:00, 46.8MB/s]
 98%|█████████▊| 167M/170M [00:03<00:00, 47.0MB/s]
100%|██████████| 170M/170M [00:04<00:00, 41.9MB/s]

Prepare a phase-only sample from the dataset

from torchvision.utils import make_grid

# Turn the first dataset image into a synthetic phase-only object.
# NOTE(review): assumes dataset[0]["input"] is a (C, H, W) float tensor — confirm.
sample = dataset[0]["input"]
sample = sample.mean(0)  # average over the first axis (channels, assumed)
sample = sample.unsqueeze(0).to(device)  # add a leading singleton dimension
sample = torchvision.transforms.Resize((128, 128))(sample)

# Min-max normalise to [0, 1]. The range is clamped to the dtype's machine
# epsilon so a constant image yields an all-zero map instead of 0/0 = NaN;
# for any non-constant image the result is unchanged.
sample_min = torch.min(sample)
value_range = (torch.max(sample) - sample_min).clamp_min(torch.finfo(sample.dtype).eps)
sample = (sample - sample_min) / value_range

# Encode the normalised values as phase: unit amplitude, arg = 2*pi*sample.
# NOTE(review): phases 0 and 2*pi coincide, so the darkest and brightest
# pixels map to the same field value, and a later ``.angle()`` returns the
# phase wrapped to [-pi, pi]. Use a range below 2*pi if that matters.
sample = torch.exp(1j * 2 * torch.pi * sample)

# Zero-pad 32 pixels on each side of the last two dims (128 x 128 -> 192 x 192).
sample = torch.nn.functional.pad(sample, (32, 32, 32, 32), mode="constant", value=0)

Optics forward model

from colibri.optics import CodedPhaseImaging
from colibri.optics.functional import coded_phase_imaging_forward, coded_phase_imaging_backward


# The optics act on the padded field's spatial size, i.e. every dimension of
# ``sample`` after the leading singleton one.
img_size = sample.shape[1:]

# Acquisition parameters. The magnitudes suggest SI metres (670e-9 ~ 670 nm).
# NOTE(review): confirm the units CodedPhaseImaging expects.
approximation = "fresnel"
sensor_distance = 50e-6
pixel_size = 1e-6
wave_length = 670e-9

# Coded phase-imaging system with its parameters frozen (trainable=False).
acquisition_model = CodedPhaseImaging(
    input_shape=img_size,
    wavelength=wave_length,
    pixel_size=pixel_size,
    sensor_distance=sensor_distance,
    approximation=approximation,
    trainable=False,
)

# Simulate the sensor measurement of ``sample``. ``intensity=True`` presumably
# requests |field|**2 instead of the complex field — TODO confirm.
y = acquisition_model(sample, type_calculation="forward", intensity=True)

Estimate phase

from colibri.recovery import LFSI


# Recover the field from the measurement ``y`` with LFSI, using the
# acquisition model built earlier in this script.
# NOTE(review): the meaning of p / k_size / sigma is defined by colibri's LFSI.
lfsi_algorithm = LFSI(
    max_iters=15,
    p=0.9,
    k_size=5,
    sigma=1.0,
    train_filter=False,
    dtype=torch.float32,
    device=device,
)

x_hat = lfsi_algorithm(y, acquisition_model)

# Detach and move everything to the CPU for plotting. ``sample`` is a complex
# field, so its phase is shown. ``x_hat`` is presumably complex too (``.angle()``
# is taken). ``y`` is shown as-is.
# NOTE(review): phase retrieval usually recovers the field only up to a
# constant global phase, so "Estimation" may be offset from "Reference".
sample = sample.detach().cpu().squeeze().angle()
y = y.detach().cpu().squeeze()
x_hat = x_hat.detach().cpu().squeeze().angle()

fig, axs = plt.subplots(1, 3, figsize=(15, 5))

# One titled grayscale panel per image, with axis ticks hidden.
panels = (("Reference", sample), ("Measurement", y), ("Estimation", x_hat))
for ax, (title, image) in zip(axs, panels):
    ax.set_title(title)
    ax.imshow(image, cmap="gray")
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()
Reference, Measurement, Estimation

Total running time of the script: (0 minutes 7.256 seconds)

Gallery generated by Sphinx-Gallery