Demo PnP

A Plug-and-Play (PnP) example: simulate compressive measurements with an optics forward model and reconstruct the image with PnP-ADMM under a DCT sparsity prior.

Select Working Directory and Device

import os
os.chdir(os.path.dirname(os.getcwd()))
print("Current Working Directory " , os.getcwd())

import sys
sys.path.append(os.getcwd())  # make the local colibri package importable

# General imports
import matplotlib.pyplot as plt
import torch

# Set random seed for reproducibility
torch.manual_seed(0)

manual_device = "cpu"  # set to None to choose the device automatically below
# Check GPU support
print("GPU support: ", torch.cuda.is_available())

if manual_device:
    device = manual_device
else:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Current Working Directory  /home/runner/work/pycolibri/pycolibri
GPU support:  False

Load dataset

from colibri.data.datasets import CustomDataset

name = 'cifar10'
path = '.'
batch_size = 1


dataset = CustomDataset(name, path)
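
To see what the loader returns, a quick inspection helps (this assumes CustomDataset follows the usual torch Dataset interface, which the indexing later in this demo relies on):

print(len(dataset))        # number of samples (assumes a standard __len__)
print(dataset[0].keys())   # each sample is a dict; the image lives under 'input'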

acquisition_name = 'spc'  # one of ['spc', 'sd_cassi', 'dd_cassi', 'c_cassi']

Visualize dataset

from colibri.recovery.terms.transforms import DCT2D

def normalize(image):
    # Min-max normalization: rescale values into the [0, 1] range
    return (image - image.min()) / (image.max() - image.min())

sample = dataset[0]['input']
sample = sample.unsqueeze(0).to(device)
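
As a small illustration of the helper above, normalize maps any tensor onto [0, 1]:

t = torch.tensor([2.0, 4.0, 6.0])
print(normalize(t))  # tensor([0.0000, 0.5000, 1.0000])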

Optics forward model

import math
from colibri.optics import SPC, SD_CASSI, DD_CASSI, C_CASSI

img_size = sample.shape[1:]  # (C, H, W), without the batch dimension

acquisition_config = dict(
    input_shape=img_size,
)

if acquisition_name == 'spc':
    n_measurements = 25**2  # number of single-pixel projections
    n_measurements_sqrt = int(math.sqrt(n_measurements))
    acquisition_config['n_measurements'] = n_measurements

acquisition_model = {
    'spc': SPC,
    'sd_cassi': SD_CASSI,
    'dd_cassi': DD_CASSI,
    'c_cassi': C_CASSI
}[acquisition_name]

acquisition_model = acquisition_model(**acquisition_config)
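
The same configuration builds any of the CASSI variants; for instance, choosing 'sd_cassi' above would amount to the following (an illustrative sketch, not used below):

sd_cassi_model = SD_CASSI(input_shape=img_size)  # illustrative alternative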

y = acquisition_model(sample)
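
Before reconstructing, it can help to confirm the tensor shapes; the exact layout of y depends on the chosen acquisition model:

print('image:       ', sample.shape)  # torch.Size([1, 3, 32, 32]) for CIFAR-10
print('measurement: ', y.shape)       # layout depends on the acquisition model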

Reconstruct image

from colibri.recovery.pnp import PnP_ADMM
from colibri.recovery.terms.prior import Sparsity
from colibri.recovery.terms.fidelity import L2
from colibri.recovery.terms.transforms import DCT2D

algo_params = {
    'max_iters': 200,  # number of ADMM iterations
    '_lambda': 0.05,   # weight of the sparsity prior
    'rho': 0.1,        # ADMM penalty parameter
    'alpha': 1e-4,     # solver step size (library-specific)
}



fidelity  = L2()
prior     = Sparsity(basis='dct')

pnp = PnP_ADMM(acquisition_model, fidelity, prior, **algo_params)
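
For orientation, PnP-ADMM in its standard form (the library's internals may differ in detail, e.g. in the role of alpha) targets

\min_x \; \tfrac{1}{2}\lVert y - H(x)\rVert_2^2 + \lambda\, g(x)

and alternates, with penalty parameter \rho,

\begin{aligned}
x^{k+1} &= \arg\min_x \tfrac{1}{2}\lVert y - H(x)\rVert_2^2 + \tfrac{\rho}{2}\lVert x - (v^k - u^k)\rVert_2^2 \\
v^{k+1} &= \operatorname{prox}_{(\lambda/\rho)\, g}\big(x^{k+1} + u^k\big) \\
u^{k+1} &= u^k + x^{k+1} - v^{k+1}
\end{aligned}

The proximal step of g is the plug-in point for a denoiser; with the orthonormal DCT sparsity prior used here it essentially reduces to soft-thresholding the DCT coefficients.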

# Back-project the measurements to image space as the initial estimate
x0 = acquisition_model.forward(y, type_calculation="backward")
x_hat = pnp(y, x0=x0, verbose=True)

basis = DCT2D()
theta = basis.forward(x_hat).detach()
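
If the prior is effective, the energy of theta should concentrate in a few DCT coefficients. A quick, illustrative check (thresholding at 1% of the peak magnitude is an arbitrary choice):

coeffs = theta.abs().flatten()
significant = (coeffs > 0.01 * coeffs.max()).float().mean()
print(f"fraction of significant DCT coefficients: {significant.item():.3f}")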

print(x_hat.shape)

plt.figure(figsize=(10,10))

plt.subplot(1,4,1)
plt.title('Reference')
plt.imshow(normalize(sample[0].permute(1, 2, 0)).cpu(), cmap='gray')
plt.xticks([])
plt.yticks([])

plt.subplot(1,4,2)
plt.title('Sparse Representation')
plt.imshow(normalize(theta[0].abs().permute(1, 2, 0)).cpu(), cmap='gray')
plt.xticks([])
plt.yticks([])


if acquisition_name == 'spc':
    # Rearrange the flat SPC measurements into a 25x25 image for display
    y = y.reshape(y.shape[0], -1, n_measurements_sqrt, n_measurements_sqrt)


plt.subplot(1,4,3)
plt.title('Measurement')
plt.imshow(normalize(y[0].permute(1, 2, 0)).cpu(), cmap='gray')
plt.xticks([])
plt.yticks([])

plt.subplot(1,4,4)
plt.title('Reconstruction')
plt.imshow(normalize(x_hat[0].permute(1, 2, 0).detach().cpu().numpy()), cmap='gray')
plt.xticks([])
plt.yticks([])

plt.show()
[Figure: Reference, Sparse Representation, Measurement, Reconstruction]
Iter:  0 fidelity:  1823.8250732421875
Iter:  1 fidelity:  0.45847102999687195
Iter:  2 fidelity:  0.44958898425102234
Iter:  3 fidelity:  0.4903907775878906
...
Iter:  197 fidelity:  0.4448702931404114
Iter:  198 fidelity:  0.4503740966320038
Iter:  199 fidelity:  0.45016589760780334
torch.Size([1, 3, 32, 32])


Total running time of the script: (0 minutes 8.020 seconds)

Gallery generated by Sphinx-Gallery