Demo PnP: image reconstruction with Plug-and-Play (PnP) ADMM.

Select Working Directory and Device

import os
os.chdir(os.path.dirname(os.getcwd()))
print("Current Working Directory:", os.getcwd())

import sys
sys.path.append(os.getcwd())

# General imports
import matplotlib.pyplot as plt
import torch

# Set random seed for reproducibility
torch.manual_seed(0)

# Set manual_device to force a device (e.g. "cpu"); leave it empty ("")
# to select CUDA automatically when available.
manual_device = "cpu"
# Check GPU support
print("GPU support: ", torch.cuda.is_available())

if manual_device:
    device = manual_device
else:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Current Working Directory: /home/runner/work/pycolibri/pycolibri
GPU support:  False

Load dataset

from colibri.data.datasets import CustomDataset

name = 'cifar10'
path = '.'
batch_size = 1


dataset = CustomDataset(name, path)

acquisition_name = 'spc'  # one of ['spc', 'sd_cassi', 'dd_cassi', 'c_cassi']
Files already downloaded and verified
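
The batch_size defined above is not used directly in this demo. As a minimal
sketch, batching with torch's standard DataLoader would look as follows
(assuming CustomDataset is a map-style dataset returning dicts with an
'input' key, as used below):

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
batch = next(iter(loader))       # dict of batched tensors
print(batch['input'].shape)      # e.g. torch.Size([1, 3, 32, 32])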

Visualize dataset

from torchvision.utils import make_grid

def normalize(image):
    # Min-max scale to [0, 1] for display
    return (image - image.min()) / (image.max() - image.min())

sample = dataset[0]['input']
sample = sample.unsqueeze(0).to(device)
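
As a small visualization sketch (assuming each dataset item is a dict with an
'input' image tensor), the imported make_grid can preview a few samples:

# Preview a grid of dataset samples (illustrative; not part of the demo output)
samples = torch.stack([dataset[i]['input'] for i in range(8)])
grid = make_grid(samples, nrow=4)
plt.imshow(normalize(grid).permute(1, 2, 0))
plt.axis('off')
plt.show()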

Optics forward model

import math
from colibri.optics import SPC, SD_CASSI, DD_CASSI, C_CASSI

img_size = sample.shape[1:]

acquisition_config = dict(
    input_shape = img_size,
)

if acquisition_name == 'spc':
    n_measurements = 25**2
    n_measurements_sqrt = int(math.sqrt(n_measurements))
    acquisition_config['n_measurements'] = n_measurements

acquisition_model = {
    'spc': SPC,
    'sd_cassi': SD_CASSI,
    'dd_cassi': DD_CASSI,
    'c_cassi': C_CASSI
}[acquisition_name]

acquisition_model = acquisition_model(**acquisition_config)

y = acquisition_model(sample)
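
Before reconstructing, a quick sanity check with plain torch shows how strongly
the acquisition compresses the image (the ratio below reflects the SPC settings
chosen above):

print('image:', tuple(sample.shape), '-> measurement:', tuple(y.shape))
print('compression ratio: %.2f' % (y.numel() / sample.numel()))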

Reconstruct image

from colibri.recovery.pnp import PnP_ADMM
from colibri.recovery.terms.prior import Sparsity
from colibri.recovery.terms.fidelity import L2
from colibri.recovery.terms.transforms import DCT2D

# PnP-ADMM hyperparameters (keys as expected by colibri's PnP_ADMM)
algo_params = {
    'max_iters': 200,
    '_lambda': 0.05,
    'rho': 0.1,
    'alpha': 1e-4,
}

fidelity  = L2()
prior     = Sparsity(basis='dct')

pnp = PnP_ADMM(acquisition_model, fidelity, prior, **algo_params)

# Initialize from the backward operator applied to the measurements
x0 = acquisition_model.forward(y, type_calculation="backward")
x_hat = pnp(y, x0=x0, verbose=True)
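
A rough quantitative check can be added with plain torch; this PSNR is computed
on min-max normalized tensors, so it is only a coarse quality proxy:

mse = torch.mean((normalize(x_hat.detach()) - normalize(sample)) ** 2)
psnr = 10 * torch.log10(1.0 / mse)
print('PSNR: %.2f dB' % psnr.item())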

basis = DCT2D()
theta = basis.forward(x_hat).detach()
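
To see why a sparsity prior in the DCT domain is sensible, one can check how
concentrated the coefficient energy is (plain torch; not part of the colibri API):

energy = theta.abs().flatten().sort(descending=True).values ** 2
k = max(1, int(0.01 * energy.numel()))
print('energy in top 1%% of DCT coefficients: %.1f%%'
      % (100 * (energy[:k].sum() / energy.sum()).item()))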

print(x_hat.shape)

plt.figure(figsize=(10,10))

plt.subplot(1,4,1)
plt.title('Reference')
plt.imshow(normalize(sample[0,:,:].permute(1, 2, 0)).cpu(), cmap='gray')
plt.xticks([])
plt.yticks([])

plt.subplot(1,4,2)
plt.title('Sparse Representation')
plt.imshow(normalize(theta[0,:,:].abs().permute(1, 2, 0)).cpu(), cmap='gray')
plt.xticks([])
plt.yticks([])


if acquisition_name == 'spc':
    y = y.reshape(y.shape[0], -1, n_measurements_sqrt, n_measurements_sqrt)


plt.subplot(1,4,3)
plt.title('Measurement')
plt.imshow(normalize(y[0,:,:].permute(1, 2, 0)).cpu(), cmap='gray')
plt.xticks([])
plt.yticks([])

plt.subplot(1,4,4)
plt.title('Reconstruction')
plt.imshow(normalize(x_hat[0,:,:].permute(1, 2, 0).detach().cpu().numpy()), cmap='gray')
plt.xticks([])
plt.yticks([])
[Figure: Reference, Sparse Representation, Measurement, Reconstruction]
Iter:  0 fidelity:  1823.8250732421875
Iter:  1 fidelity:  0.4588412940502167
Iter:  2 fidelity:  0.4521573483943939
... (iterations 3-197 omitted; after the first step the fidelity fluctuates in roughly the 0.43-0.53 range) ...
Iter:  198 fidelity:  0.48330551385879517
Iter:  199 fidelity:  0.4570406973361969
torch.Size([1, 3, 32, 32])


Total running time of the script: (0 minutes 7.933 seconds)
