Demo Datasets.

This example shows how to use the CustomDataset class to load the predefined datasets in the repository.

Select Working Directory

import os
from random import randint

from torch.utils.data import DataLoader

os.chdir(os.path.dirname(os.getcwd()))
print("Current Working Directory ", os.getcwd())
Current Working Directory  /home/runner/work/pycolibri/pycolibri

Load builtin dataset

from colibri.data.datasets import CustomDataset


name = 'cifar10'  # ['cifar10', 'cifar100', 'mnist', 'fashion_mnist', 'cave']
path = 'data'
batch_size = 128


dataset = CustomDataset(name, path)

dataset_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar10/cifar-10-python.tar.gz

100%|██████████| 170498071/170498071 [00:01<00:00, 104515008.19it/s]
Extracting data/cifar10/cifar-10-python.tar.gz to data/cifar10
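
As a rough guide to how many batches the loader above yields, the arithmetic can be sketched as follows. This is only an illustration: the helper name batch_layout is hypothetical, and the sample count of 50,000 assumes that CustomDataset exposes the CIFAR-10 training split, which the example does not state. With the DataLoader default drop_last=False, the final batch holds whatever samples remain.

```python
import math
from typing import Tuple


def batch_layout(num_samples: int, batch_size: int) -> Tuple[int, int]:
    """Return (number of batches, size of the last batch) for drop_last=False."""
    if num_samples <= 0 or batch_size <= 0:
        raise ValueError("num_samples and batch_size must be positive")
    num_batches = math.ceil(num_samples / batch_size)
    # Every batch except the last one is full.
    last_batch = num_samples - (num_batches - 1) * batch_size
    return num_batches, last_batch


# Assumption: 50,000 samples (CIFAR-10 training split), batch_size = 128 as above.
num_batches, last_batch = batch_layout(50_000, 128)
print(num_batches, last_batch)  # 391 80
```

If the dataset actually has a different length, len(dataset) gives the value to pass as num_samples.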

Visualize cifar10 dataset

import matplotlib.pyplot as plt

data = next(iter(dataset_loader))
image = data['input']
label = data['output']

plt.figure(figsize=(5, 5))
plt.suptitle(f'{name.upper()} dataset Samples')

for i in range(9):
    plt.subplot(3, 3, i + 1)
    plt.imshow(image[i].permute(1, 2, 0).cpu().numpy())  # CHW -> HWC, as imshow expects
    plt.title(f'Label: {label[i]}')
    plt.axis('off')

plt.tight_layout()
plt.show()
[Figure: CIFAR10 dataset Samples. 3 x 3 grid of images titled Label: 3, Label: 8, Label: 8, Label: 0, Label: 6, Label: 6, Label: 1, Label: 6, Label: 3]

CAVE dataset

CAVE is a database of multispectral images that were used to emulate the GAP (generalized assorted pixel) camera. The images cover a wide variety of real-world materials and objects. You can download the dataset from http://www.cs.columbia.edu/CAVE/databases/multispectral/ and, once it is downloaded, you must extract the files and place them in the ‘data’ folder.
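
The extraction step could look roughly like the sketch below. It uses only the standard library. The helper name extract_archive and the archive file name in the usage comment are hypothetical, and the folder layout that CustomDataset expects inside ‘data’ is not specified here, so check it against the repository before relying on this.

```python
import zipfile
from pathlib import Path


def extract_archive(archive_path, target_dir="data"):
    """Extract a ZIP archive into target_dir and return the archived names."""
    archive_path = Path(archive_path)
    target_dir = Path(target_dir)
    # is_zipfile also returns False when the file does not exist.
    if not zipfile.is_zipfile(archive_path):
        raise ValueError(f"{archive_path} is not a readable ZIP archive")
    target_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(target_dir)
        return archive.namelist()


# Usage (the file name is an assumption, use the name of the archive you downloaded):
# extract_archive("cave_download.zip", "data")
```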

import requests
import zipfile
import os

dataset = CustomDataset("cave", path)

dataset_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

Visualize CAVE dataset

def normalize(image):
    return (image - image.min()) / (image.max() - image.min())

data = next(iter(dataset_loader))
rgb_image = data['input']
spec_image = data['output']

[_, _, M, N] = spec_image.shape

plt.figure(figsize=(8, 8))

for i in range(3):
    # randint includes both endpoints, so the upper bounds are M - 1 and N - 1
    coord1 = [randint(0, M - 1), randint(0, N - 1)]
    coord2 = [randint(0, M - 1), randint(0, N - 1)]

    plt.subplot(3, 3, (3 * i) + 1)
    plt.imshow(normalize(rgb_image[i].permute(1, 2, 0).cpu().numpy()))
    plt.title('rgb')
    plt.axis('off')

    plt.subplot(3, 3, (3 * i) + 2)
    plt.imshow(normalize(spec_image[i, [18, 12, 8]].permute(1, 2, 0).cpu().numpy()))
    plt.scatter(coord1[1], coord1[0], s=120, edgecolors='black')
    plt.scatter(coord2[1], coord2[0], s=120, edgecolors='black')
    plt.title('spec bands [18, 12, 8]')
    plt.axis('off')

    plt.subplot(3, 3, (3 * i) + 3)
    plt.plot(normalize(spec_image[i, :, coord1[0], coord1[1]].cpu().numpy()), linewidth=2, label='p1')
    plt.plot(normalize(spec_image[i, :, coord2[0], coord2[1]].cpu().numpy()), linewidth=2, label='p2')
    plt.title('spec signatures')
    plt.xlabel('Band index')  # the signatures are plotted against band position, not nanometres
    plt.grid()
    plt.legend()

plt.tight_layout()
plt.show()
[Figure: three rows, each with the panels rgb, spec bands [18, 12, 8], and spec signatures]
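
The signature plots above place the band position on the horizontal axis. To express that axis in nanometres, a band index can be mapped to a centre wavelength. The sketch below assumes that the loaded cube keeps all 31 CAVE bands in ascending order, covering 400 nm to 700 nm in 10 nm steps as described for the original database. The helper name band_wavelengths is hypothetical, and the assumption should be checked against spec_image.shape[1].

```python
def band_wavelengths(num_bands=31, start_nm=400.0, stop_nm=700.0):
    """Centre wavelength in nm of each band, evenly spaced from start to stop."""
    if num_bands < 2:
        raise ValueError("num_bands must be at least 2")
    step = (stop_nm - start_nm) / (num_bands - 1)
    return [start_nm + k * step for k in range(num_bands)]


wavelengths = band_wavelengths()
# Wavelengths of the three bands used for the false-colour panel.
print([wavelengths[b] for b in (18, 12, 8)])  # [580.0, 520.0, 480.0]
```

Passing wavelengths as the first argument of plt.plot would then put the signatures on a nanometre axis.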
