Note
Go to the end to download the full example code.
Demo: Plug-and-Play (PnP) ADMM reconstruction.
Select Working Directory and Device
import os
import sys

# Step up one level from the launch directory and work from there.
os.chdir(os.path.dirname(os.getcwd()))

# Report the new working directory and make it importable.
working_dir = os.getcwd()
print("Current Working Directory ", working_dir)
sys.path.append(working_dir)
#General imports
import matplotlib.pyplot as plt
import torch
import os
# Set random seed for reproducibility
torch.manual_seed(0)

# Device override: a device string such as "cpu" or "cuda:0" forces that
# device; a falsy value (None or "") enables automatic selection below.
manual_device = "cpu"

# Check GPU support
print("GPU support: ", torch.cuda.is_available())

# `device` is always a torch.device, whichever branch is taken, so code that
# consumes it never has to handle a plain string in one case and a
# torch.device object in the other.
if manual_device:
    device = torch.device(manual_device)
else:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Current Working Directory /home/runner/work/pycolibri/pycolibri
GPU support: False
Load dataset
from colibri.data.datasets import CustomDataset
# Dataset identifier and the location handed to CustomDataset
# (presumably the directory the data is read from / stored in -- verify
# against CustomDataset).
name, path = "cifar10", "."

# Kept as a module-level setting; it is not used in the visible part of
# this script.
batch_size = 1

dataset = CustomDataset(name, path)

# Optical system to simulate. The value is used as a key into the
# acquisition-model lookup later in this script, whose keys are:
# 'spc', 'sd_cassi', 'dd_cassi', 'c_cassi'.
acquisition_name = "spc"
Visualize dataset
from torchvision.utils import make_grid
from colibri.recovery.terms.transforms import DCT2D
def normalize(image):
    """Rescale *image* linearly so its minimum maps to 0 and its maximum to 1.

    Works on any array-like object exposing ``min()``/``max()`` and
    element-wise arithmetic; in this script it is called with both torch
    tensors and NumPy arrays.

    A constant image (maximum equal to minimum) is returned as all zeros
    instead of the NaNs that a 0/0 division would produce.
    """
    lowest = image.min()
    span = image.max() - lowest
    if span == 0:
        # Flat image: divide by 1 so the result is zeros with the same
        # shape and the same (floating-point) result type as the normal path.
        span = 1
    return (image - lowest) / span
# Take the 'input' entry of the first dataset item, prepend a batch axis
# and move it to the selected device.
sample = dataset[0]["input"].unsqueeze(0).to(device)
Optics forward model
import math
from colibri.optics import SPC, SD_CASSI, DD_CASSI, C_CASSI
# Lookup from acquisition name to the optics class that implements it.
acquisition_classes = {
    'spc': SPC,
    'sd_cassi': SD_CASSI,
    'dd_cassi': DD_CASSI,
    'c_cassi': C_CASSI,
}

# Fail early with an actionable message instead of a bare KeyError when the
# configured name is not one of the supported systems.
if acquisition_name not in acquisition_classes:
    raise ValueError(
        f"Unknown acquisition_name {acquisition_name!r}; "
        f"expected one of {sorted(acquisition_classes)}"
    )

# Shape of a single sample, i.e. everything after the leading batch axis.
img_size = sample.shape[1:]
acquisition_config = dict(
    input_shape=img_size,
)

if acquisition_name == 'spc':
    n_measurements = 25**2
    # Integer square root: exact for perfect squares, with no float
    # round-trip. Used later to reshape the measurements into a square
    # image for display.
    n_measurements_sqrt = math.isqrt(n_measurements)
    acquisition_config['n_measurements'] = n_measurements

# Instantiate the selected optical system and simulate the measurements.
acquisition_model = acquisition_classes[acquisition_name](**acquisition_config)
y = acquisition_model(sample)
# Reconstruct image
from colibri.recovery.pnp import PnP_ADMM
from colibri.recovery.terms.prior import Sparsity
from colibri.recovery.terms.fidelity import L2
from colibri.recovery.terms.transforms import DCT2D
# Solver settings, forwarded to PnP_ADMM as keyword arguments.
algo_params = dict(
    max_iters=200,
    _lambda=0.05,
    rho=0.1,
    alpha=1e-4,
)

# Data-fidelity term and prior plugged into the solver.
fidelity, prior = L2(), Sparsity(basis="dct")
pnp = PnP_ADMM(acquisition_model, fidelity, prior, **algo_params)

# Starting point: the acquisition model evaluated in its "backward" mode
# on the measurements.
x0 = acquisition_model.forward(y, type_calculation="backward")

# Run the solver; verbose=True reports progress while iterating.
x_hat = pnp(y, x0=x0, verbose=True)
# DCT coefficients of the reconstruction, shown as its sparse representation.
basis = DCT2D()
theta = basis.forward(x_hat).detach()
print(x_hat.shape)

if acquisition_name == 'spc':
    # Fold the SPC measurements into square images so they can be displayed.
    y = y.reshape(y.shape[0], -1, n_measurements_sqrt, n_measurements_sqrt)

# (title, batch) pairs, drawn left to right. Only the first item of each
# batch is displayed.
panels = [
    ('Reference', sample),
    ('Sparse Representation', abs(theta)),
    ('Measurement', y),
    ('Reconstruction', x_hat),
]

plt.figure(figsize=(10, 10))
for position, (title, batch) in enumerate(panels, start=1):
    # Detach and move to host memory before plotting so every panel works
    # regardless of the device in use, then put the channel axis last as
    # imshow expects. x_hat is no longer rescaled in place: normalize()
    # already maps each panel to [0, 1], and the displayed result is the same.
    image = batch[0].detach().cpu().permute(1, 2, 0)
    plt.subplot(1, len(panels), position)
    plt.title(title)
    plt.imshow(normalize(image).numpy(), cmap='gray')
    plt.xticks([])
    plt.yticks([])

Iter: 0 fidelity: 1823.8250732421875
Iter: 1 fidelity: 0.45847102999687195
Iter: 2 fidelity: 0.44958898425102234
Iter: 3 fidelity: 0.4903907775878906
Iter: 4 fidelity: 0.5036434531211853
Iter: 5 fidelity: 0.46289652585983276
Iter: 6 fidelity: 0.46167850494384766
Iter: 7 fidelity: 0.4766285717487335
Iter: 8 fidelity: 0.5017406344413757
Iter: 9 fidelity: 0.5099577903747559
Iter: 10 fidelity: 0.4726678729057312
Iter: 11 fidelity: 0.45456114411354065
Iter: 12 fidelity: 0.4924970865249634
Iter: 13 fidelity: 0.4805953800678253
Iter: 14 fidelity: 0.5070435404777527
Iter: 15 fidelity: 0.49452561140060425
Iter: 16 fidelity: 0.49283596873283386
Iter: 17 fidelity: 0.4556082487106323
Iter: 18 fidelity: 0.4565475285053253
Iter: 19 fidelity: 0.4709772765636444
Iter: 20 fidelity: 0.47070276737213135
Iter: 21 fidelity: 0.4848296344280243
Iter: 22 fidelity: 0.47601625323295593
Iter: 23 fidelity: 0.48504260182380676
Iter: 24 fidelity: 0.4547821283340454
Iter: 25 fidelity: 0.47771814465522766
Iter: 26 fidelity: 0.49328500032424927
Iter: 27 fidelity: 0.46497905254364014
Iter: 28 fidelity: 0.5056253671646118
Iter: 29 fidelity: 0.4698447585105896
Iter: 30 fidelity: 0.4777454435825348
Iter: 31 fidelity: 0.47850847244262695
Iter: 32 fidelity: 0.48482534289360046
Iter: 33 fidelity: 0.463372141122818
Iter: 34 fidelity: 0.48124295473098755
Iter: 35 fidelity: 0.49649521708488464
Iter: 36 fidelity: 0.49170276522636414
Iter: 37 fidelity: 0.4784253239631653
Iter: 38 fidelity: 0.49354681372642517
Iter: 39 fidelity: 0.48022356629371643
Iter: 40 fidelity: 0.4928903579711914
Iter: 41 fidelity: 0.4767721891403198
Iter: 42 fidelity: 0.4471803903579712
Iter: 43 fidelity: 0.44240549206733704
Iter: 44 fidelity: 0.46999695897102356
Iter: 45 fidelity: 0.4670599699020386
Iter: 46 fidelity: 0.48000529408454895
Iter: 47 fidelity: 0.5143378973007202
Iter: 48 fidelity: 0.4879104495048523
Iter: 49 fidelity: 0.4662438929080963
Iter: 50 fidelity: 0.4857798218727112
Iter: 51 fidelity: 0.4780791997909546
Iter: 52 fidelity: 0.4672298729419708
Iter: 53 fidelity: 0.4526367783546448
Iter: 54 fidelity: 0.469298392534256
Iter: 55 fidelity: 0.4722084701061249
Iter: 56 fidelity: 0.4786020815372467
Iter: 57 fidelity: 0.476763516664505
Iter: 58 fidelity: 0.4987906515598297
Iter: 59 fidelity: 0.49083784222602844
Iter: 60 fidelity: 0.48026806116104126
Iter: 61 fidelity: 0.4669716954231262
Iter: 62 fidelity: 0.4535066783428192
Iter: 63 fidelity: 0.48329734802246094
Iter: 64 fidelity: 0.4722241759300232
Iter: 65 fidelity: 0.45725709199905396
Iter: 66 fidelity: 0.500185489654541
Iter: 67 fidelity: 0.4910682141780853
Iter: 68 fidelity: 0.48815199732780457
Iter: 69 fidelity: 0.48378047347068787
Iter: 70 fidelity: 0.5000890493392944
Iter: 71 fidelity: 0.4904384911060333
Iter: 72 fidelity: 0.4777902066707611
Iter: 73 fidelity: 0.47398048639297485
Iter: 74 fidelity: 0.46570438146591187
Iter: 75 fidelity: 0.4943172037601471
Iter: 76 fidelity: 0.48548564314842224
Iter: 77 fidelity: 0.47611844539642334
Iter: 78 fidelity: 0.45608386397361755
Iter: 79 fidelity: 0.47031277418136597
Iter: 80 fidelity: 0.4956846833229065
Iter: 81 fidelity: 0.5105018615722656
Iter: 82 fidelity: 0.49623218178749084
Iter: 83 fidelity: 0.483551949262619
Iter: 84 fidelity: 0.49874788522720337
Iter: 85 fidelity: 0.48844894766807556
Iter: 86 fidelity: 0.497370183467865
Iter: 87 fidelity: 0.4726472496986389
Iter: 88 fidelity: 0.48355600237846375
Iter: 89 fidelity: 0.47068318724632263
Iter: 90 fidelity: 0.459929883480072
Iter: 91 fidelity: 0.4760059714317322
Iter: 92 fidelity: 0.47184836864471436
Iter: 93 fidelity: 0.44122856855392456
Iter: 94 fidelity: 0.4692312479019165
Iter: 95 fidelity: 0.4644646644592285
Iter: 96 fidelity: 0.4649282991886139
Iter: 97 fidelity: 0.4865356981754303
Iter: 98 fidelity: 0.46691417694091797
Iter: 99 fidelity: 0.4747282564640045
Iter: 100 fidelity: 0.47051581740379333
Iter: 101 fidelity: 0.48752740025520325
Iter: 102 fidelity: 0.4961963891983032
Iter: 103 fidelity: 0.47290244698524475
Iter: 104 fidelity: 0.47549647092819214
Iter: 105 fidelity: 0.485628604888916
Iter: 106 fidelity: 0.48715919256210327
Iter: 107 fidelity: 0.47755470871925354
Iter: 108 fidelity: 0.4773774743080139
Iter: 109 fidelity: 0.45718932151794434
Iter: 110 fidelity: 0.4800483286380768
Iter: 111 fidelity: 0.4842332601547241
Iter: 112 fidelity: 0.47745582461357117
Iter: 113 fidelity: 0.4880981147289276
Iter: 114 fidelity: 0.4857094883918762
Iter: 115 fidelity: 0.4946031868457794
Iter: 116 fidelity: 0.5000946521759033
Iter: 117 fidelity: 0.4806794822216034
Iter: 118 fidelity: 0.4912639260292053
Iter: 119 fidelity: 0.5104355812072754
Iter: 120 fidelity: 0.4718891978263855
Iter: 121 fidelity: 0.464810848236084
Iter: 122 fidelity: 0.4460732638835907
Iter: 123 fidelity: 0.4495674967765808
Iter: 124 fidelity: 0.4825018346309662
Iter: 125 fidelity: 0.4683647155761719
Iter: 126 fidelity: 0.4464244544506073
Iter: 127 fidelity: 0.4625248908996582
Iter: 128 fidelity: 0.47225552797317505
Iter: 129 fidelity: 0.46243324875831604
Iter: 130 fidelity: 0.4596928060054779
Iter: 131 fidelity: 0.4643561542034149
Iter: 132 fidelity: 0.4734286963939667
Iter: 133 fidelity: 0.49313274025917053
Iter: 134 fidelity: 0.4594385027885437
Iter: 135 fidelity: 0.4899147152900696
Iter: 136 fidelity: 0.4753158390522003
Iter: 137 fidelity: 0.4694339334964752
Iter: 138 fidelity: 0.4899340569972992
Iter: 139 fidelity: 0.4980129599571228
Iter: 140 fidelity: 0.4900995194911957
Iter: 141 fidelity: 0.47961893677711487
Iter: 142 fidelity: 0.4924887418746948
Iter: 143 fidelity: 0.4853378236293793
Iter: 144 fidelity: 0.5101111531257629
Iter: 145 fidelity: 0.49907827377319336
Iter: 146 fidelity: 0.4817403554916382
Iter: 147 fidelity: 0.4812775254249573
Iter: 148 fidelity: 0.5018429756164551
Iter: 149 fidelity: 0.501549243927002
Iter: 150 fidelity: 0.492750883102417
Iter: 151 fidelity: 0.500935971736908
Iter: 152 fidelity: 0.486393541097641
Iter: 153 fidelity: 0.4763394296169281
Iter: 154 fidelity: 0.49091339111328125
Iter: 155 fidelity: 0.4723436236381531
Iter: 156 fidelity: 0.486462265253067
Iter: 157 fidelity: 0.48860785365104675
Iter: 158 fidelity: 0.4643516540527344
Iter: 159 fidelity: 0.45436805486679077
Iter: 160 fidelity: 0.4563407003879547
Iter: 161 fidelity: 0.45678722858428955
Iter: 162 fidelity: 0.4561789929866791
Iter: 163 fidelity: 0.48098769783973694
Iter: 164 fidelity: 0.4956667423248291
Iter: 165 fidelity: 0.45661309361457825
Iter: 166 fidelity: 0.4494413435459137
Iter: 167 fidelity: 0.4665078818798065
Iter: 168 fidelity: 0.46128109097480774
Iter: 169 fidelity: 0.45225024223327637
Iter: 170 fidelity: 0.44372695684432983
Iter: 171 fidelity: 0.4654732346534729
Iter: 172 fidelity: 0.4590889513492584
Iter: 173 fidelity: 0.4674101769924164
Iter: 174 fidelity: 0.43849772214889526
Iter: 175 fidelity: 0.46396586298942566
Iter: 176 fidelity: 0.47848421335220337
Iter: 177 fidelity: 0.4718315303325653
Iter: 178 fidelity: 0.4963637888431549
Iter: 179 fidelity: 0.48315590620040894
Iter: 180 fidelity: 0.5001769065856934
Iter: 181 fidelity: 0.4849590063095093
Iter: 182 fidelity: 0.4832480847835541
Iter: 183 fidelity: 0.462694376707077
Iter: 184 fidelity: 0.4576461911201477
Iter: 185 fidelity: 0.4593297243118286
Iter: 186 fidelity: 0.46447476744651794
Iter: 187 fidelity: 0.44857257604599
Iter: 188 fidelity: 0.4642430543899536
Iter: 189 fidelity: 0.47062957286834717
Iter: 190 fidelity: 0.4609135687351227
Iter: 191 fidelity: 0.42410239577293396
Iter: 192 fidelity: 0.43785351514816284
Iter: 193 fidelity: 0.4392968416213989
Iter: 194 fidelity: 0.44521307945251465
Iter: 195 fidelity: 0.4627082645893097
Iter: 196 fidelity: 0.47079896926879883
Iter: 197 fidelity: 0.4448702931404114
Iter: 198 fidelity: 0.4503740966320038
Iter: 199 fidelity: 0.45016589760780334
torch.Size([1, 3, 32, 32])
([], [])
Total running time of the script: (0 minutes 8.020 seconds)