Demo Datasets.
In this example, we show how to use the CustomDataset class from colibri.data.datasets to load the datasets that are predefined in the repository.
Select Working Directory
import os
from random import randint
from torch.utils.data import DataLoader
# Move one directory up so that the relative 'data' path resolves from the repository root
os.chdir(os.path.dirname(os.getcwd()))
print("Current Working Directory ", os.getcwd())
Current Working Directory /home/runner/work/pycolibri/pycolibri
Load builtin dataset
from colibri.data.datasets import CustomDataset
name = 'cifar10' # ['cifar10', 'cifar100', 'mnist', 'fashion_mnist', 'cave']
path = 'data'
batch_size = 128
dataset = CustomDataset(name, path)
dataset_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar10/cifar-10-python.tar.gz
Extracting data/cifar10/cifar-10-python.tar.gz to data/cifar10
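Each batch drawn from dataset_loader is indexed like a dictionary (data['input'] and data['output'] in the next step), which suggests that CustomDataset returns one dict per sample and that PyTorch's default collation stacks each key across the batch. The NumPy-only sketch below mimics that stacking under this assumption. ToyDictDataset, collate_dicts and iterate_batches are hypothetical names used for illustration; they are not part of colibri or PyTorch.

```python
import numpy as np


class ToyDictDataset:
    """Hypothetical stand-in for a dataset that returns one dict per sample."""

    def __init__(self, num_samples=10, channels=3, height=32, width=32):
        rng = np.random.default_rng(0)
        # Images stored channel-first (C, H, W), the usual PyTorch convention.
        self.images = rng.random((num_samples, channels, height, width), dtype=np.float32)
        self.labels = rng.integers(0, 10, size=num_samples)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return {'input': self.images[idx], 'output': self.labels[idx]}


def collate_dicts(samples):
    """Stack each key across a list of dict samples (mimics default collation)."""
    return {key: np.stack([s[key] for s in samples]) for key in samples[0]}


def iterate_batches(dataset, batch_size):
    """Yield collated batches in order, keeping a short last batch (no shuffling)."""
    for start in range(0, len(dataset), batch_size):
        end = min(start + batch_size, len(dataset))
        yield collate_dicts([dataset[i] for i in range(start, end)])


toy_dataset = ToyDictDataset()
first = next(iterate_batches(toy_dataset, batch_size=4))
print(first['input'].shape)   # (4, 3, 32, 32)
print(first['output'].shape)  # (4,)
```

The leading dimension of every batched value is the batch size, which is why the plotting code indexes image[i] and label[i].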
Visualize cifar10 dataset
import matplotlib.pyplot as plt
data = next(iter(dataset_loader))
image = data['input']
label = data['output']
plt.figure(figsize=(5, 5))
plt.suptitle(f'{name.upper()} dataset Samples')
for i in range(9):
    plt.subplot(3, 3, i + 1)
    # (C, H, W) tensor -> (H, W, C) array, the layout imshow expects
    plt.imshow(image[i].permute(1, 2, 0).cpu().numpy())
    plt.title(f'Label: {label[i]}')
    plt.axis('off')
plt.tight_layout()
plt.show()
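Matplotlib's imshow expects colour images laid out as (height, width, channels), while PyTorch image tensors are conventionally channel-first (C, H, W). The sketch below uses NumPy's transpose, the counterpart of Tensor.permute, on a synthetic non-square array (not CIFAR-10 data) to show that the order (1, 2, 0) yields (H, W, C), whereas (2, 1, 0) yields (W, H, C) and therefore swaps the rows and columns of the image.

```python
import numpy as np

# Synthetic channel-first image: 3 channels, 4 rows, 6 columns.
chw = np.arange(3 * 4 * 6).reshape(3, 4, 6)

# (1, 2, 0): channels move last, rows and columns keep their order -> (H, W, C).
hwc = chw.transpose(1, 2, 0)

# (2, 1, 0): channels move last but rows and columns are swapped -> (W, H, C).
whc = chw.transpose(2, 1, 0)

print(hwc.shape)  # (4, 6, 3)
print(whc.shape)  # (6, 4, 3)

# The value at channel 2, row 1, column 5 sits at hwc[1, 5, 2] but at whc[5, 1, 2]:
# the second layout is the spatial transpose of the first.
assert hwc[1, 5, 2] == chw[2, 1, 5]
assert whc[5, 1, 2] == chw[2, 1, 5]
assert np.array_equal(whc, hwc.transpose(1, 0, 2))
```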
Cave dataset
CAVE is a database of multispectral images that were used to emulate the GAP camera. The images show a wide variety of real-world materials and objects. The dataset can be downloaded from the following link:
http://www.cs.columbia.edu/CAVE/databases/multispectral/
Once downloaded, extract the files and place them in the ‘data’ folder (the path passed to CustomDataset).
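If the CAVE download arrives as a ZIP archive, Python's standard zipfile module can extract it into the data folder. This is a minimal sketch: the archive name in the usage comment and the target folder data/cave are assumptions, and the exact folder layout that CustomDataset expects should be checked against the colibri documentation. extract_archive is a hypothetical helper name.

```python
import zipfile
from pathlib import Path


def extract_archive(archive_path, target_dir):
    """Extract a ZIP archive into target_dir and return the names it contained.

    The folder layout CustomDataset expects is an assumption, not taken from colibri.
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive_path) as archive:
        names = archive.namelist()
        archive.extractall(target)
    return names


# Usage (hypothetical archive name and target folder):
# extract_archive('cave_download.zip', 'data/cave')
```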
dataset = CustomDataset("cave", path)
dataset_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
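The spectral images in this example are indexed as (sample, band, row, column). CAVE images are distributed with 31 bands from 400 nm to 700 nm in 10 nm steps; assuming CustomDataset keeps all 31 bands in that order (not confirmed here), band index k maps to 400 + 10 * k nm, so the bands [18, 12, 8] used for the false-colour view correspond to 580, 520 and 480 nm, a rough red/green/blue triple. The NumPy sketch below uses a random cube in place of real data; band_wavelengths, random_pixel and spectral_signature are hypothetical helper names. Note that random.randint(a, b) includes b, so a valid row index comes from randint(0, M - 1) or, equivalently, randrange(M).

```python
import random

import numpy as np


def band_wavelengths(num_bands=31, start_nm=400.0, step_nm=10.0):
    """Wavelength in nm of each band, assuming evenly spaced CAVE-style bands."""
    return start_nm + step_nm * np.arange(num_bands)


def random_pixel(rows, cols):
    """Pick a valid (row, col); randrange excludes its upper bound."""
    return random.randrange(rows), random.randrange(cols)


def spectral_signature(cube, sample, row, col):
    """Spectrum of one pixel from a (batch, bands, M, N) array."""
    return cube[sample, :, row, col]


# Synthetic stand-in for a batch of spectral images: 2 samples, 31 bands, 64 x 48 pixels.
cube = np.random.default_rng(0).random((2, 31, 64, 48))
_, num_bands, M, N = cube.shape

wavelengths = band_wavelengths(num_bands)
row, col = random_pixel(M, N)
signature = spectral_signature(cube, 0, row, col)

print(wavelengths[[18, 12, 8]].tolist())  # [580.0, 520.0, 480.0]
print(signature.shape)                    # (31,)
```

If the 31-band assumption holds, passing wavelengths as the first argument of plt.plot would put the x-axis of a signature plot in nanometres instead of band indices.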
Visualize cave dataset
def normalize(image):
    # Min-max scaling to [0, 1]
    return (image - image.min()) / (image.max() - image.min())
data = next(iter(dataset_loader))
rgb_image = data['input']
spec_image = data['output']
[_, _, M, N] = spec_image.shape
plt.figure(figsize=(8, 8))
for i in range(3):
    # random.randint includes both endpoints, so cap the indices at M - 1 and N - 1
    coord1 = [randint(0, M - 1), randint(0, N - 1)]
    coord2 = [randint(0, M - 1), randint(0, N - 1)]
    plt.subplot(3, 3, (3 * i) + 1)
    plt.imshow(normalize(rgb_image[i].permute(1, 2, 0).cpu().numpy()))
    plt.title('rgb')
    plt.axis('off')
    plt.subplot(3, 3, (3 * i) + 2)
    plt.imshow(normalize(spec_image[i, [18, 12, 8]].permute(1, 2, 0).cpu().numpy()))
    plt.scatter(coord1[1], coord1[0], s=120, edgecolors='black')
    plt.scatter(coord2[1], coord2[0], s=120, edgecolors='black')
    plt.title('spec bands [18, 12, 8]')
    plt.axis('off')
    plt.subplot(3, 3, (3 * i) + 3)
    plt.plot(normalize(spec_image[i, :, coord1[0], coord1[1]].cpu().numpy()), linewidth=2, label='p1')
    plt.plot(normalize(spec_image[i, :, coord2[0], coord2[1]].cpu().numpy()), linewidth=2, label='p2')
    plt.title('spec signatures')
    plt.xlabel('Band index')
    plt.grid()
    plt.legend()
plt.tight_layout()
plt.show()
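The normalize helper used for the CAVE plots divides by image.max() - image.min(), which is zero for a constant input, so the result would be NaN. A minimal NumPy variant that returns zeros in that case is sketched below; safe_normalize is a hypothetical name and is not part of colibri.

```python
import numpy as np


def safe_normalize(image):
    """Min-max scale to [0, 1]; a constant input maps to all zeros instead of NaN."""
    image = np.asarray(image, dtype=np.float64)
    lo, hi = image.min(), image.max()
    if hi == lo:
        # Constant input: avoid 0 / 0.
        return np.zeros_like(image)
    return (image - lo) / (hi - lo)


print(safe_normalize([2.0, 4.0, 6.0]).tolist())           # [0.0, 0.5, 1.0]
print(safe_normalize(np.full((2, 2), 7.0)).tolist())      # [[0.0, 0.0], [0.0, 0.0]]
```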
Total running time of the script: (0 minutes 21.345 seconds)