Demo PnP: Plug-and-Play reconstruction with ADMM
Select Working Directory and Device
import os
os.chdir(os.path.dirname(os.getcwd()))
print("Current Working Directory ", os.getcwd())

import sys
sys.path.append(os.getcwd())
# General imports
import matplotlib.pyplot as plt
import torch

# Set random seed for reproducibility
torch.manual_seed(0)

# Set manual_device to a device string (e.g. "cpu" or "cuda:0") to override
# automatic selection; leave it empty to pick the GPU when available.
manual_device = "cpu"

# Check GPU support
print("GPU support: ", torch.cuda.is_available())

if manual_device:
    device = manual_device
else:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Current Working Directory /home/runner/work/pycolibri/pycolibri
GPU support: False
Load dataset
from colibri.data.datasets import CustomDataset
name = 'cifar10'
path = '.'
batch_size = 1
dataset = CustomDataset(name, path)
acquisition_name = 'spc'  # one of ['spc', 'sd_cassi', 'dd_cassi', 'c_cassi']
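The batch_size defined above is not used directly in this demo, which indexes the dataset sample by sample. If you prefer batched access, a minimal sketch using a standard PyTorch DataLoader (assuming, as below, that each item is a dict with an 'input' tensor):

from torch.utils.data import DataLoader

# Illustrative only; the demo itself indexes dataset[0] directly.
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
batch = next(iter(loader))
print(batch['input'].shape)  # e.g. torch.Size([1, 3, 32, 32]) for cifar10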
Visualize dataset
from torchvision.utils import make_grid
from colibri.recovery.terms.transforms import DCT2D

def normalize(image):
    # Scale to [0, 1] for display
    return (image - image.min()) / (image.max() - image.min())

sample = dataset[0]['input']
sample = sample.unsqueeze(0).to(device)
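As a quick sanity check (not part of the original script), you can verify the sparsity assumption that motivates the recovery below: most of the image's energy in the DCT domain concentrates in a few coefficients. This sketch reuses the DCT2D transform imported above:

# Illustrative only: inspect how fast the sorted DCT magnitudes decay.
dct = DCT2D()
coeffs = dct.forward(sample).detach()
sorted_mag = coeffs.abs().flatten().sort(descending=True).values
print(sorted_mag[:10] / sorted_mag.sum())  # leading coefficients dominate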
Optics forward model
import math

from colibri.optics import SPC, SD_CASSI, DD_CASSI, C_CASSI

img_size = sample.shape[1:]

acquisition_config = dict(
    input_shape=img_size,
)

if acquisition_name == 'spc':
    n_measurements = 25**2
    n_measurements_sqrt = int(math.sqrt(n_measurements))
    acquisition_config['n_measurements'] = n_measurements

acquisition_model = {
    'spc': SPC,
    'sd_cassi': SD_CASSI,
    'dd_cassi': DD_CASSI,
    'c_cassi': C_CASSI,
}[acquisition_name]

acquisition_model = acquisition_model(**acquisition_config)

y = acquisition_model(sample)
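For context (our notation, not from the script): the single-pixel camera acquires linear projections of the scene,

$$\mathbf{y} = \mathbf{H}\,\mathbf{x}, \qquad \mathbf{H} \in \mathbb{R}^{m \times n},$$

with $m = 25^2 = 625$ measurements against $n = 32 \times 32 = 1024$ pixels per channel, i.e. a compression ratio of $m/n \approx 0.61$.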
Reconstruct image
from colibri.recovery.pnp import PnP_ADMM
from colibri.recovery.terms.prior import Sparsity
from colibri.recovery.terms.fidelity import L2

algo_params = {
    'max_iters': 200,
    '_lambda': 0.05,
    'rho': 0.1,
    'alpha': 1e-4,
}

fidelity = L2()
prior = Sparsity(basis='dct')

pnp = PnP_ADMM(acquisition_model, fidelity, prior, **algo_params)

# Initialize with the backward (adjoint) projection of the measurements
x0 = acquisition_model.forward(y, type_calculation="backward")
x_hat = pnp(y, x0=x0, verbose=True)

basis = DCT2D()
theta = basis.forward(x_hat).detach()

print(x_hat.shape)
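As a reminder of the problem PnP_ADMM targets (written in the standard plug-and-play ADMM form; the exact update rules inside colibri's implementation may differ), the reconstruction solves

$$\hat{\mathbf{x}} = \arg\min_{\mathbf{x}} \tfrac{1}{2}\|\mathbf{y}-\mathbf{H}\mathbf{x}\|_2^2 + \lambda\, g(\mathbf{x})$$

via the alternating updates

$$\mathbf{x}^{k+1} = \arg\min_{\mathbf{x}} \tfrac{1}{2}\|\mathbf{y}-\mathbf{H}\mathbf{x}\|_2^2 + \tfrac{\rho}{2}\|\mathbf{x}-(\mathbf{v}^k-\mathbf{u}^k)\|_2^2,$$
$$\mathbf{v}^{k+1} = \mathrm{prox}_{(\lambda/\rho)\, g}(\mathbf{x}^{k+1}+\mathbf{u}^k),$$
$$\mathbf{u}^{k+1} = \mathbf{u}^k + \mathbf{x}^{k+1} - \mathbf{v}^{k+1}.$$

Here _lambda plays the role of $\lambda$ and rho of $\rho$; alpha is presumably the step size of the inner solver for the x-update.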
plt.figure(figsize=(10, 10))

plt.subplot(1, 4, 1)
plt.title('Reference')
plt.imshow(normalize(sample[0, :, :].permute(1, 2, 0)), cmap='gray')
plt.xticks([])
plt.yticks([])

plt.subplot(1, 4, 2)
plt.title('Sparse Representation')
plt.imshow(normalize(abs(theta[0, :, :]).permute(1, 2, 0)), cmap='gray')
plt.xticks([])
plt.yticks([])

# SPC measurements come as a flat vector; reshape into a square for display
if acquisition_name == 'spc':
    y = y.reshape(y.shape[0], -1, n_measurements_sqrt, n_measurements_sqrt)

plt.subplot(1, 4, 3)
plt.title('Measurement')
plt.imshow(normalize(y[0, :, :].permute(1, 2, 0)), cmap='gray')
plt.xticks([])
plt.yticks([])

plt.subplot(1, 4, 4)
plt.title('Reconstruction')
plt.imshow(normalize(x_hat[0, :, :].permute(1, 2, 0).detach().cpu().numpy()), cmap='gray')
plt.xticks([])
plt.yticks([])
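The script reports no quantitative metric; if you want one, a minimal PSNR check of the normalized reconstruction against the reference (our addition, not in the original):

# Illustrative only: PSNR between normalized reconstruction and reference.
rec = normalize(x_hat.detach())
ref = normalize(sample)
mse = torch.mean((rec - ref) ** 2)
print(f"PSNR: {10 * torch.log10(1.0 / mse).item():.2f} dB")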

Iter: 0 fidelity: 1763.8348388671875
Iter: 1 fidelity: 0.4410426616668701
Iter: 2 fidelity: 0.4588572680950165
Iter: 3 fidelity: 0.4156217575073242
Iter: 4 fidelity: 0.4227311313152313
...
Iter: 198 fidelity: 0.4467231333255768
Iter: 199 fidelity: 0.41497907042503357
torch.Size([1, 3, 32, 32])
Total running time of the script: (0 minutes 8.442 seconds)