Demo PnP.

Select Working Directory and Device

import os
import sys

# Step up to the parent of the current directory and make that directory
# importable, so packages located there can be imported below.
os.chdir(os.path.dirname(os.getcwd()))
print("Current Working Directory ", os.getcwd())
sys.path.append(os.getcwd())

# General imports
import matplotlib.pyplot as plt
import torch

# Fix the random seed for reproducibility
torch.manual_seed(0)

# Report whether CUDA is usable
print("GPU support: ", torch.cuda.is_available())

# A non-empty ``manual_device`` takes precedence; otherwise use the first
# CUDA device when available and fall back to the CPU.
manual_device = "cpu"
device = (
    manual_device
    if manual_device
    else torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
)
Current Working Directory  /home/runner/work/pycolibri/pycolibri
GPU support:  False

Load dataset

from colibri.data.datasets import CustomDataset

# Dataset identifier and location, both passed to CustomDataset below.
name = 'cifar10'
path = '.'
# NOTE(review): batch_size is not referenced anywhere else in the visible script.
batch_size = 1


dataset = CustomDataset(name, path)

# Key used to pick the optics class from the lookup table in the forward-model
# section; the keys defined there are 'spc', 'sd_cassi', 'dd_cassi' and 'c_cassi'.
acquisition_name = 'spc'

Visualize dataset

from torchvision.utils import make_grid
from colibri.recovery.terms.transforms import DCT2D

def normalize(image):
    """Rescale ``image`` linearly so that its values span ``[0, 1]``.

    Args:
        image: Array-like object exposing ``min()`` and ``max()`` and
            supporting elementwise arithmetic (this script passes torch
            tensors and a NumPy array).

    Returns:
        ``(image - image.min()) / (image.max() - image.min())``. When every
        element is equal the range is zero; in that case an all-zero result
        is returned instead of the NaNs a 0/0 division would produce.
    """
    lowest = image.min()
    value_range = image.max() - lowest
    if value_range == 0:
        # Constant input: divide by 1 so the result is 0 everywhere, not NaN.
        value_range = 1
    return (image - lowest) / value_range

# Take the 'input' entry of the first dataset item, add a leading dimension
# and move the result to the selected device.
sample = dataset[0]['input'].unsqueeze(0).to(device)

Optics forward model

import math
from colibri.optics import SPC, SD_CASSI, DD_CASSI, C_CASSI

# Shape of the sample without its leading dimension.
img_size = sample.shape[1:]

# Keyword arguments handed to the optics class constructor.
acquisition_config = {'input_shape': img_size}

# The 'spc' option additionally receives the number of measurements; its
# integer square root is kept for reshaping the measurements when plotting.
if acquisition_name == 'spc':
    n_measurements = 25 ** 2
    n_measurements_sqrt = int(math.sqrt(n_measurements))
    acquisition_config['n_measurements'] = n_measurements

# Look up the optics class by name and instantiate it with the config.
acquisition_model = {
    'spc': SPC,
    'sd_cassi': SD_CASSI,
    'dd_cassi': DD_CASSI,
    'c_cassi': C_CASSI,
}[acquisition_name](**acquisition_config)

# Measurements of the sample produced by the optics model.
y = acquisition_model(sample)

# Reconstruct image
from colibri.recovery.pnp import PnP_ADMM
from colibri.recovery.terms.prior import Sparsity
from colibri.recovery.terms.fidelity import L2
from colibri.recovery.terms.transforms import DCT2D

# Solver settings forwarded to PnP_ADMM as keyword arguments.
algo_params = dict(max_iters=200, _lambda=0.05, rho=0.1, alpha=1e-4)

# Data-fidelity term and prior handed to the solver.
fidelity = L2()
prior = Sparsity(basis='dct')
pnp = PnP_ADMM(acquisition_model, fidelity, prior, **algo_params)

# Starting point: the optics model evaluated with type_calculation="backward".
x0 = acquisition_model.forward(y, type_calculation="backward")
x_hat = pnp(y, x0=x0, verbose=True)

# Transform of the reconstruction, detached for plotting.
basis = DCT2D()
theta = basis.forward(x_hat).detach()

print(x_hat.shape)

def _as_image(tensor):
    """Return the first item of ``tensor`` as a normalized, channel-last NumPy array.

    The tensor is detached and copied to the CPU first, so plotting works
    regardless of the selected ``device`` or of gradient tracking.
    """
    return normalize(tensor[0].detach().cpu().permute(1, 2, 0).numpy())


def _show_panel(position, title, image):
    """Draw ``image`` in slot ``position`` of the 1x4 figure, without axis ticks."""
    plt.subplot(1, 4, position)
    plt.title(title)
    plt.imshow(image, cmap='gray')
    plt.xticks([])
    plt.yticks([])


plt.figure(figsize=(10,10))

_show_panel(1, 'Reference', _as_image(sample))
_show_panel(2, 'Sparse Representation', _as_image(abs(theta)))

if acquisition_name == 'spc':
    # Lay the SPC measurements out as a square grid so they can be shown as an image.
    y = y.reshape(y.shape[0], -1, n_measurements_sqrt, n_measurements_sqrt)

_show_panel(3, 'Measurement', _as_image(y))

# In-place rescaling of the reconstruction, kept so that ``x_hat`` ends the
# script in the same state as before.
x_hat -= x_hat.min()
x_hat /= x_hat.max()
_show_panel(4, 'Reconstruction', _as_image(x_hat))

# Display the figure when the example is run as a plain script.
plt.show()
Reference, Sparse Representation, Measurement, Reconstruction
Iter:  0 fidelity:  1763.8348388671875
Iter:  1 fidelity:  0.4410426616668701
Iter:  2 fidelity:  0.4588572680950165
Iter:  3 fidelity:  0.4156217575073242
Iter:  4 fidelity:  0.4227311313152313
Iter:  5 fidelity:  0.4200991988182068
Iter:  6 fidelity:  0.42072609066963196
Iter:  7 fidelity:  0.4392266869544983
Iter:  8 fidelity:  0.43167391419410706
Iter:  9 fidelity:  0.443916916847229
Iter:  10 fidelity:  0.4492640495300293
Iter:  11 fidelity:  0.44792377948760986
Iter:  12 fidelity:  0.44789060950279236
Iter:  13 fidelity:  0.4270676374435425
Iter:  14 fidelity:  0.4409791827201843
Iter:  15 fidelity:  0.41702353954315186
Iter:  16 fidelity:  0.45346349477767944
Iter:  17 fidelity:  0.45864659547805786
Iter:  18 fidelity:  0.4631957709789276
Iter:  19 fidelity:  0.4573199152946472
Iter:  20 fidelity:  0.45692211389541626
Iter:  21 fidelity:  0.43396034836769104
Iter:  22 fidelity:  0.4407581388950348
Iter:  23 fidelity:  0.442766010761261
Iter:  24 fidelity:  0.43382853269577026
Iter:  25 fidelity:  0.4617076516151428
Iter:  26 fidelity:  0.4424895942211151
Iter:  27 fidelity:  0.4607693552970886
Iter:  28 fidelity:  0.45213210582733154
Iter:  29 fidelity:  0.4115031957626343
Iter:  30 fidelity:  0.4481014013290405
Iter:  31 fidelity:  0.4597848057746887
Iter:  32 fidelity:  0.45216798782348633
Iter:  33 fidelity:  0.4456397593021393
Iter:  34 fidelity:  0.464450478553772
Iter:  35 fidelity:  0.417095810174942
Iter:  36 fidelity:  0.4383118748664856
Iter:  37 fidelity:  0.4094241261482239
Iter:  38 fidelity:  0.4626232087612152
Iter:  39 fidelity:  0.41598060727119446
Iter:  40 fidelity:  0.43267130851745605
Iter:  41 fidelity:  0.43120911717414856
Iter:  42 fidelity:  0.4084026515483856
Iter:  43 fidelity:  0.4401206970214844
Iter:  44 fidelity:  0.4154187738895416
Iter:  45 fidelity:  0.410198837518692
Iter:  46 fidelity:  0.4283309280872345
Iter:  47 fidelity:  0.4405807554721832
Iter:  48 fidelity:  0.44300687313079834
Iter:  49 fidelity:  0.45959553122520447
Iter:  50 fidelity:  0.4079134166240692
Iter:  51 fidelity:  0.4227355122566223
Iter:  52 fidelity:  0.4610675275325775
Iter:  53 fidelity:  0.4408958852291107
Iter:  54 fidelity:  0.4623047411441803
Iter:  55 fidelity:  0.43309691548347473
Iter:  56 fidelity:  0.4552910029888153
Iter:  57 fidelity:  0.4770914316177368
Iter:  58 fidelity:  0.438189834356308
Iter:  59 fidelity:  0.4409826397895813
Iter:  60 fidelity:  0.43765175342559814
Iter:  61 fidelity:  0.4514778256416321
Iter:  62 fidelity:  0.41658055782318115
Iter:  63 fidelity:  0.43532243371009827
Iter:  64 fidelity:  0.43618375062942505
Iter:  65 fidelity:  0.4785388112068176
Iter:  66 fidelity:  0.4568840265274048
Iter:  67 fidelity:  0.4581933915615082
Iter:  68 fidelity:  0.4420369267463684
Iter:  69 fidelity:  0.45786252617836
Iter:  70 fidelity:  0.4434940814971924
Iter:  71 fidelity:  0.44693881273269653
Iter:  72 fidelity:  0.445995956659317
Iter:  73 fidelity:  0.4503181576728821
Iter:  74 fidelity:  0.4167283773422241
Iter:  75 fidelity:  0.42678728699684143
Iter:  76 fidelity:  0.4206433594226837
Iter:  77 fidelity:  0.417487233877182
Iter:  78 fidelity:  0.43929409980773926
Iter:  79 fidelity:  0.42145130038261414
Iter:  80 fidelity:  0.44566354155540466
Iter:  81 fidelity:  0.4378133714199066
Iter:  82 fidelity:  0.41719555854797363
Iter:  83 fidelity:  0.4529130458831787
Iter:  84 fidelity:  0.44688382744789124
Iter:  85 fidelity:  0.45443060994148254
Iter:  86 fidelity:  0.43581724166870117
Iter:  87 fidelity:  0.44248852133750916
Iter:  88 fidelity:  0.42302802205085754
Iter:  89 fidelity:  0.4380691945552826
Iter:  90 fidelity:  0.4347599446773529
Iter:  91 fidelity:  0.45021796226501465
Iter:  92 fidelity:  0.45276498794555664
Iter:  93 fidelity:  0.45549991726875305
Iter:  94 fidelity:  0.4577929377555847
Iter:  95 fidelity:  0.4450231194496155
Iter:  96 fidelity:  0.4610012471675873
Iter:  97 fidelity:  0.46698421239852905
Iter:  98 fidelity:  0.4670000970363617
Iter:  99 fidelity:  0.4281656742095947
Iter:  100 fidelity:  0.41380831599235535
Iter:  101 fidelity:  0.4254639446735382
Iter:  102 fidelity:  0.439962774515152
Iter:  103 fidelity:  0.4092889726161957
Iter:  104 fidelity:  0.4455636143684387
Iter:  105 fidelity:  0.4487783908843994
Iter:  106 fidelity:  0.464395135641098
Iter:  107 fidelity:  0.48244065046310425
Iter:  108 fidelity:  0.4547707140445709
Iter:  109 fidelity:  0.43435025215148926
Iter:  110 fidelity:  0.42674490809440613
Iter:  111 fidelity:  0.4303472638130188
Iter:  112 fidelity:  0.4349454939365387
Iter:  113 fidelity:  0.41352716088294983
Iter:  114 fidelity:  0.39870038628578186
Iter:  115 fidelity:  0.44299522042274475
Iter:  116 fidelity:  0.4143761396408081
Iter:  117 fidelity:  0.4420458972454071
Iter:  118 fidelity:  0.42148685455322266
Iter:  119 fidelity:  0.4278317987918854
Iter:  120 fidelity:  0.41203874349594116
Iter:  121 fidelity:  0.42008838057518005
Iter:  122 fidelity:  0.42733004689216614
Iter:  123 fidelity:  0.4368440806865692
Iter:  124 fidelity:  0.437879353761673
Iter:  125 fidelity:  0.4771577715873718
Iter:  126 fidelity:  0.4165879487991333
Iter:  127 fidelity:  0.4409368634223938
Iter:  128 fidelity:  0.4390147030353546
Iter:  129 fidelity:  0.46549680829048157
Iter:  130 fidelity:  0.45326948165893555
Iter:  131 fidelity:  0.4526884853839874
Iter:  132 fidelity:  0.45529958605766296
Iter:  133 fidelity:  0.4510740041732788
Iter:  134 fidelity:  0.4367515444755554
Iter:  135 fidelity:  0.4283357262611389
Iter:  136 fidelity:  0.460586279630661
Iter:  137 fidelity:  0.4261094629764557
Iter:  138 fidelity:  0.4477663040161133
Iter:  139 fidelity:  0.42584624886512756
Iter:  140 fidelity:  0.4282936453819275
Iter:  141 fidelity:  0.4165240228176117
Iter:  142 fidelity:  0.42082566022872925
Iter:  143 fidelity:  0.4110695421695709
Iter:  144 fidelity:  0.4335503578186035
Iter:  145 fidelity:  0.41406407952308655
Iter:  146 fidelity:  0.4318605661392212
Iter:  147 fidelity:  0.40235450863838196
Iter:  148 fidelity:  0.44176632165908813
Iter:  149 fidelity:  0.4322507083415985
Iter:  150 fidelity:  0.42798563838005066
Iter:  151 fidelity:  0.41995009779930115
Iter:  152 fidelity:  0.43623024225234985
Iter:  153 fidelity:  0.44719696044921875
Iter:  154 fidelity:  0.419412761926651
Iter:  155 fidelity:  0.428877055644989
Iter:  156 fidelity:  0.42070186138153076
Iter:  157 fidelity:  0.40830478072166443
Iter:  158 fidelity:  0.44546905159950256
Iter:  159 fidelity:  0.44580215215682983
Iter:  160 fidelity:  0.43814077973365784
Iter:  161 fidelity:  0.4220871925354004
Iter:  162 fidelity:  0.44385263323783875
Iter:  163 fidelity:  0.47086188197135925
Iter:  164 fidelity:  0.4527910351753235
Iter:  165 fidelity:  0.4622322916984558
Iter:  166 fidelity:  0.4516114890575409
Iter:  167 fidelity:  0.4426148533821106
Iter:  168 fidelity:  0.44719424843788147
Iter:  169 fidelity:  0.4200136959552765
Iter:  170 fidelity:  0.4267568588256836
Iter:  171 fidelity:  0.4320485591888428
Iter:  172 fidelity:  0.4350747764110565
Iter:  173 fidelity:  0.41372567415237427
Iter:  174 fidelity:  0.43268778920173645
Iter:  175 fidelity:  0.43483689427375793
Iter:  176 fidelity:  0.41189542412757874
Iter:  177 fidelity:  0.4129372239112854
Iter:  178 fidelity:  0.40711039304733276
Iter:  179 fidelity:  0.4094131290912628
Iter:  180 fidelity:  0.4145738184452057
Iter:  181 fidelity:  0.4390333890914917
Iter:  182 fidelity:  0.4160551428794861
Iter:  183 fidelity:  0.41116413474082947
Iter:  184 fidelity:  0.43597856163978577
Iter:  185 fidelity:  0.433633416891098
Iter:  186 fidelity:  0.42772671580314636
Iter:  187 fidelity:  0.4350983500480652
Iter:  188 fidelity:  0.427231103181839
Iter:  189 fidelity:  0.4080568850040436
Iter:  190 fidelity:  0.4173363149166107
Iter:  191 fidelity:  0.4523578882217407
Iter:  192 fidelity:  0.41230273246765137
Iter:  193 fidelity:  0.41046687960624695
Iter:  194 fidelity:  0.4156299829483032
Iter:  195 fidelity:  0.4250345230102539
Iter:  196 fidelity:  0.4185555875301361
Iter:  197 fidelity:  0.42614153027534485
Iter:  198 fidelity:  0.4467231333255768
Iter:  199 fidelity:  0.41497907042503357
torch.Size([1, 3, 32, 32])

([], [])

Total running time of the script: (0 minutes 8.442 seconds)

Gallery generated by Sphinx-Gallery