Note
Go to the end to download the full example code.
Demo: image reconstruction with Plug-and-Play ADMM (PnP-ADMM).
Select Working Directory and Device
import os
import sys

# Move up one directory and put it on sys.path so the ``colibri`` package
# imported below can be found.
# NOTE(review): assumes the script is launched from a direct sub-directory
# of the repository root — confirm.
os.chdir(os.path.dirname(os.getcwd()))
print("Current Working Directory " , os.getcwd())
sys.path.append(os.getcwd())

# General imports
import matplotlib.pyplot as plt
import torch

# Set random seed for reproducibility
torch.manual_seed(0)

# A truthy device string (e.g. "cpu", "cuda:0") forces that device;
# set to None to pick CUDA automatically when it is available.
manual_device = "cpu"

# Check GPU support
print("GPU support: ", torch.cuda.is_available())

# Always produce a torch.device so `device` has one type on both branches.
if manual_device:
    device = torch.device(manual_device)
else:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Current Working Directory /home/runner/work/pycolibri/pycolibri
GPU support: False
Load dataset
from colibri.data.datasets import CustomDataset

name = 'cifar10'
path = '.'
batch_size = 1  # NOTE(review): not referenced anywhere else in this script
dataset = CustomDataset(name, path)

# Must be one of the keys of the acquisition-model table in the
# "Optics forward model" section: 'spc', 'sd_cassi', 'dd_cassi', 'c_cassi'.
acquisition_name = 'spc'
Files already downloaded and verified
Visualize dataset
from torchvision.utils import make_grid
from colibri.recovery.terms.transforms import DCT2D
def normalize(image):
    """Min-max rescale *image* to the range [0, 1].

    Accepts any array-like with ``min``/``max`` methods and elementwise
    arithmetic; this script passes both torch tensors and numpy arrays.
    A constant image has zero range, so it is mapped to all zeros instead
    of producing NaNs from a 0/0 division.
    """
    shifted = image - image.min()
    value_range = image.max() - image.min()
    if value_range == 0:
        return shifted
    return shifted / value_range
# First dataset item, given a leading batch axis and moved to `device`.
sample = dataset[0]['input'].unsqueeze(0).to(device)
Optics forward model
import math
from colibri.optics import SPC, SD_CASSI, DD_CASSI, C_CASSI

# Per-sample shape: drop the batch axis added by unsqueeze(0).
img_size = sample.shape[1:]

acquisition_config = dict(
    input_shape = img_size,
)

if acquisition_name == 'spc':
    n_measurements = 25**2
    # The measurements are reshaped into a square grid for display later,
    # so the count must be a perfect square.
    n_measurements_sqrt = math.isqrt(n_measurements)
    if n_measurements_sqrt ** 2 != n_measurements:
        raise ValueError(
            f"n_measurements must be a perfect square, got {n_measurements}"
        )
    acquisition_config['n_measurements'] = n_measurements

acquisition_models = {
    'spc': SPC,
    'sd_cassi': SD_CASSI,
    'dd_cassi': DD_CASSI,
    'c_cassi': C_CASSI
}
if acquisition_name not in acquisition_models:
    raise ValueError(
        f"Unknown acquisition_name {acquisition_name!r}; "
        f"expected one of {sorted(acquisition_models)}"
    )
acquisition_model = acquisition_models[acquisition_name](**acquisition_config)

# Simulated measurements of the reference sample.
y = acquisition_model(sample)
# Reconstruct image
from colibri.recovery.pnp import PnP_ADMM
from colibri.recovery.terms.prior import Sparsity
from colibri.recovery.terms.fidelity import L2
from colibri.recovery.terms.transforms import DCT2D

# Hyper-parameters forwarded as keyword arguments to PnP_ADMM.
algo_params = dict(
    max_iters=200,
    _lambda=0.05,
    rho=0.1,
    alpha=1e-4,
)

fidelity = L2()
prior = Sparsity(basis='dct')
pnp = PnP_ADMM(acquisition_model, fidelity, prior, **algo_params)

# Starting point: the acquisition model's "backward" operator applied to y.
x0 = acquisition_model.forward(y, type_calculation="backward")
x_hat = pnp(y, x0=x0, verbose=True)

# DCT coefficients of the reconstruction, detached for plotting.
basis = DCT2D()
theta = basis.forward(x_hat).detach()

print(x_hat.shape)
def _to_display(image):
    """Return a (C, H, W) tensor as a CPU numpy (H, W, C) array scaled to [0, 1].

    Detaching and moving to the CPU lets matplotlib draw tensors that
    require grad or live on a GPU device.
    """
    return normalize(image.detach().cpu().permute(1, 2, 0).numpy())


def _show_panel(position, title, image):
    """Draw *image* in slot *position* of the 1x4 subplot grid, without ticks."""
    plt.subplot(1, 4, position)
    plt.title(title)
    plt.imshow(_to_display(image), cmap='gray')
    plt.xticks([])
    plt.yticks([])


plt.figure(figsize=(10,10))

_show_panel(1, 'Reference', sample[0])
_show_panel(2, 'Sparse Representation', theta[0].abs())

# SPC measurements are reshaped into a square grid so they can be shown as an
# image; a separate name keeps the measurement tensor `y` itself unchanged.
if acquisition_name == 'spc':
    y_display = y.reshape(y.shape[0], -1, n_measurements_sqrt, n_measurements_sqrt)
else:
    y_display = y
_show_panel(3, 'Measurement', y_display[0])

# _to_display already min-max normalizes, so x_hat is not rescaled in place.
_show_panel(4, 'Reconstruction', x_hat[0])

plt.show()
Iter: 0 fidelity: 1823.8250732421875
Iter: 1 fidelity: 0.4588412940502167
Iter: 2 fidelity: 0.4521573483943939
Iter: 3 fidelity: 0.4939662516117096
Iter: 4 fidelity: 0.4838429093360901
Iter: 5 fidelity: 0.4674130976200104
Iter: 6 fidelity: 0.47328874468803406
Iter: 7 fidelity: 0.4724617004394531
Iter: 8 fidelity: 0.4942083954811096
Iter: 9 fidelity: 0.5231730341911316
Iter: 10 fidelity: 0.49712416529655457
Iter: 11 fidelity: 0.48921072483062744
Iter: 12 fidelity: 0.4982928931713104
Iter: 13 fidelity: 0.47966650128364563
Iter: 14 fidelity: 0.4922373592853546
Iter: 15 fidelity: 0.5256263017654419
Iter: 16 fidelity: 0.46956440806388855
Iter: 17 fidelity: 0.4411783516407013
Iter: 18 fidelity: 0.4918080270290375
Iter: 19 fidelity: 0.4854259490966797
Iter: 20 fidelity: 0.47752174735069275
Iter: 21 fidelity: 0.4904305636882782
Iter: 22 fidelity: 0.49157047271728516
Iter: 23 fidelity: 0.5041679739952087
Iter: 24 fidelity: 0.45151734352111816
Iter: 25 fidelity: 0.49566733837127686
Iter: 26 fidelity: 0.48823535442352295
Iter: 27 fidelity: 0.4821318984031677
Iter: 28 fidelity: 0.4753139317035675
Iter: 29 fidelity: 0.4719826579093933
Iter: 30 fidelity: 0.467926949262619
Iter: 31 fidelity: 0.4601953625679016
Iter: 32 fidelity: 0.47511589527130127
Iter: 33 fidelity: 0.4775075316429138
Iter: 34 fidelity: 0.48247894644737244
Iter: 35 fidelity: 0.5023109316825867
Iter: 36 fidelity: 0.46881091594696045
Iter: 37 fidelity: 0.464871883392334
Iter: 38 fidelity: 0.49159756302833557
Iter: 39 fidelity: 0.4830591678619385
Iter: 40 fidelity: 0.4641215205192566
Iter: 41 fidelity: 0.4673115015029907
Iter: 42 fidelity: 0.4544020891189575
Iter: 43 fidelity: 0.48190852999687195
Iter: 44 fidelity: 0.46221795678138733
Iter: 45 fidelity: 0.4778132140636444
Iter: 46 fidelity: 0.4770721197128296
Iter: 47 fidelity: 0.4883291721343994
Iter: 48 fidelity: 0.49068543314933777
Iter: 49 fidelity: 0.46687108278274536
Iter: 50 fidelity: 0.44757819175720215
Iter: 51 fidelity: 0.468693345785141
Iter: 52 fidelity: 0.45363548398017883
Iter: 53 fidelity: 0.47714918851852417
Iter: 54 fidelity: 0.4360399544239044
Iter: 55 fidelity: 0.4452935755252838
Iter: 56 fidelity: 0.455309122800827
Iter: 57 fidelity: 0.45123225450515747
Iter: 58 fidelity: 0.5037573575973511
Iter: 59 fidelity: 0.49649378657341003
Iter: 60 fidelity: 0.45265620946884155
Iter: 61 fidelity: 0.4776271879673004
Iter: 62 fidelity: 0.455546110868454
Iter: 63 fidelity: 0.47066882252693176
Iter: 64 fidelity: 0.4680958092212677
Iter: 65 fidelity: 0.5035679340362549
Iter: 66 fidelity: 0.47781145572662354
Iter: 67 fidelity: 0.5034958124160767
Iter: 68 fidelity: 0.46567460894584656
Iter: 69 fidelity: 0.4996751546859741
Iter: 70 fidelity: 0.4713825583457947
Iter: 71 fidelity: 0.5069085359573364
Iter: 72 fidelity: 0.48992982506752014
Iter: 73 fidelity: 0.4663350284099579
Iter: 74 fidelity: 0.4771512448787689
Iter: 75 fidelity: 0.45832350850105286
Iter: 76 fidelity: 0.4587981700897217
Iter: 77 fidelity: 0.465382844209671
Iter: 78 fidelity: 0.44641053676605225
Iter: 79 fidelity: 0.45496219396591187
Iter: 80 fidelity: 0.4926610589027405
Iter: 81 fidelity: 0.490904301404953
Iter: 82 fidelity: 0.49758845567703247
Iter: 83 fidelity: 0.5221526622772217
Iter: 84 fidelity: 0.48811525106430054
Iter: 85 fidelity: 0.5111722946166992
Iter: 86 fidelity: 0.4802142083644867
Iter: 87 fidelity: 0.4678373634815216
Iter: 88 fidelity: 0.4818473160266876
Iter: 89 fidelity: 0.48692023754119873
Iter: 90 fidelity: 0.45147448778152466
Iter: 91 fidelity: 0.4647853374481201
Iter: 92 fidelity: 0.46839869022369385
Iter: 93 fidelity: 0.4599814713001251
Iter: 94 fidelity: 0.473921537399292
Iter: 95 fidelity: 0.4745449721813202
Iter: 96 fidelity: 0.4538070857524872
Iter: 97 fidelity: 0.4550686180591583
Iter: 98 fidelity: 0.4743611514568329
Iter: 99 fidelity: 0.47406086325645447
Iter: 100 fidelity: 0.48384758830070496
Iter: 101 fidelity: 0.4587833881378174
Iter: 102 fidelity: 0.47915005683898926
Iter: 103 fidelity: 0.4679105877876282
Iter: 104 fidelity: 0.4882251024246216
Iter: 105 fidelity: 0.4574156105518341
Iter: 106 fidelity: 0.4814000427722931
Iter: 107 fidelity: 0.4839552640914917
Iter: 108 fidelity: 0.4860583245754242
Iter: 109 fidelity: 0.5005573630332947
Iter: 110 fidelity: 0.4787122905254364
Iter: 111 fidelity: 0.5059583187103271
Iter: 112 fidelity: 0.5059174299240112
Iter: 113 fidelity: 0.49114280939102173
Iter: 114 fidelity: 0.5088287591934204
Iter: 115 fidelity: 0.5050346851348877
Iter: 116 fidelity: 0.4986196756362915
Iter: 117 fidelity: 0.4961117208003998
Iter: 118 fidelity: 0.5008435249328613
Iter: 119 fidelity: 0.49607762694358826
Iter: 120 fidelity: 0.4658684730529785
Iter: 121 fidelity: 0.4853409230709076
Iter: 122 fidelity: 0.47063949704170227
Iter: 123 fidelity: 0.4604876637458801
Iter: 124 fidelity: 0.46362996101379395
Iter: 125 fidelity: 0.4725472927093506
Iter: 126 fidelity: 0.44457167387008667
Iter: 127 fidelity: 0.48953163623809814
Iter: 128 fidelity: 0.45130473375320435
Iter: 129 fidelity: 0.45590975880622864
Iter: 130 fidelity: 0.4291759133338928
Iter: 131 fidelity: 0.4539843201637268
Iter: 132 fidelity: 0.4466753900051117
Iter: 133 fidelity: 0.45587456226348877
Iter: 134 fidelity: 0.4564124047756195
Iter: 135 fidelity: 0.47375592589378357
Iter: 136 fidelity: 0.46691179275512695
Iter: 137 fidelity: 0.46365270018577576
Iter: 138 fidelity: 0.4427866041660309
Iter: 139 fidelity: 0.4629698693752289
Iter: 140 fidelity: 0.48216697573661804
Iter: 141 fidelity: 0.4700765013694763
Iter: 142 fidelity: 0.46772539615631104
Iter: 143 fidelity: 0.483059823513031
Iter: 144 fidelity: 0.4540703594684601
Iter: 145 fidelity: 0.47707268595695496
Iter: 146 fidelity: 0.45440369844436646
Iter: 147 fidelity: 0.46783187985420227
Iter: 148 fidelity: 0.44450879096984863
Iter: 149 fidelity: 0.4660499393939972
Iter: 150 fidelity: 0.47018739581108093
Iter: 151 fidelity: 0.45445260405540466
Iter: 152 fidelity: 0.4523097276687622
Iter: 153 fidelity: 0.44541943073272705
Iter: 154 fidelity: 0.4599485397338867
Iter: 155 fidelity: 0.4572729468345642
Iter: 156 fidelity: 0.44617724418640137
Iter: 157 fidelity: 0.4508606195449829
Iter: 158 fidelity: 0.46234965324401855
Iter: 159 fidelity: 0.456813782453537
Iter: 160 fidelity: 0.4312227666378021
Iter: 161 fidelity: 0.4784868359565735
Iter: 162 fidelity: 0.47421884536743164
Iter: 163 fidelity: 0.4725705087184906
Iter: 164 fidelity: 0.4910188317298889
Iter: 165 fidelity: 0.4808518886566162
Iter: 166 fidelity: 0.46139562129974365
Iter: 167 fidelity: 0.4879647493362427
Iter: 168 fidelity: 0.47251981496810913
Iter: 169 fidelity: 0.47271448373794556
Iter: 170 fidelity: 0.4492793083190918
Iter: 171 fidelity: 0.46626415848731995
Iter: 172 fidelity: 0.46247348189353943
Iter: 173 fidelity: 0.46758657693862915
Iter: 174 fidelity: 0.46457210183143616
Iter: 175 fidelity: 0.4603961408138275
Iter: 176 fidelity: 0.45892202854156494
Iter: 177 fidelity: 0.48949548602104187
Iter: 178 fidelity: 0.49830764532089233
Iter: 179 fidelity: 0.49861496686935425
Iter: 180 fidelity: 0.48783573508262634
Iter: 181 fidelity: 0.47912368178367615
Iter: 182 fidelity: 0.4767632484436035
Iter: 183 fidelity: 0.4843460023403168
Iter: 184 fidelity: 0.49366745352745056
Iter: 185 fidelity: 0.49002256989479065
Iter: 186 fidelity: 0.4791885018348694
Iter: 187 fidelity: 0.47707828879356384
Iter: 188 fidelity: 0.48180022835731506
Iter: 189 fidelity: 0.46015745401382446
Iter: 190 fidelity: 0.4682173728942871
Iter: 191 fidelity: 0.4836132228374481
Iter: 192 fidelity: 0.48343923687934875
Iter: 193 fidelity: 0.4824574589729309
Iter: 194 fidelity: 0.46241557598114014
Iter: 195 fidelity: 0.45795103907585144
Iter: 196 fidelity: 0.43898525834083557
Iter: 197 fidelity: 0.4662410318851471
Iter: 198 fidelity: 0.48330551385879517
Iter: 199 fidelity: 0.4570406973361969
torch.Size([1, 3, 32, 32])
([], [])
Total running time of the script: (0 minutes 7.933 seconds)