Demo Datasets.

This example shows how to use the CustomDataset class to load the datasets built into the library.

Select Working Directory

import os
from random import randint

from torch.utils.data import DataLoader

os.chdir(os.path.dirname(os.getcwd()))
print("Current Working Directory ", os.getcwd())
Current Working Directory  /home/runner/work/pycolibri/pycolibri
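
The `os.chdir` call above moves the process one directory up, presumably so that relative paths such as `data` resolve from the repository root. A minimal sketch of a variant that restores the previous directory afterwards; the helper name `working_directory` is hypothetical and not part of the library:

```python
import os
from contextlib import contextmanager


@contextmanager
def working_directory(path):
    """Temporarily switch the current working directory (hypothetical helper)."""
    previous = os.getcwd()
    os.chdir(path)
    try:
        yield os.getcwd()
    finally:
        # Always return to the starting directory, even if the body raises.
        os.chdir(previous)


if __name__ == "__main__":
    parent = os.path.dirname(os.getcwd())
    with working_directory(parent) as cwd:
        print("Inside the block:", cwd)
    print("After the block: ", os.getcwd())
```

This keeps the directory change local to one block, which can help when the same session later runs code that expects the original location.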

Load builtin dataset

from colibri.data.datasets import CustomDataset


name = 'fashion_mnist'  # ['cifar10', 'cifar100', 'mnist', 'fashion_mnist', 'cave']
path = 'data'
batch_size = 128
builtin_train = True
builtin_download = True

dataset = CustomDataset(name, path, builtin_train, builtin_download)
dataset_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)
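
With `shuffle=False`, the loader walks the dataset in order and groups consecutive samples into batches of `batch_size`. The arithmetic can be sketched without PyTorch. The sample count of 60,000 is the size of the Fashion-MNIST training split, and the helper `batch_ranges` is hypothetical:

```python
import math


def batch_ranges(num_samples, batch_size):
    """Yield (start, stop) index pairs for sequential batches, keeping the last partial batch."""
    for start in range(0, num_samples, batch_size):
        yield start, min(start + batch_size, num_samples)


num_samples = 60_000   # Fashion-MNIST training split
batch_size = 128       # same value as in the example above

ranges = list(batch_ranges(num_samples, batch_size))
assert len(ranges) == math.ceil(num_samples / batch_size)

first_start, first_stop = ranges[0]
last_start, last_stop = ranges[-1]
print("number of batches:", len(ranges))
print("first batch size:", first_stop - first_start)
print("last batch size:", last_stop - last_start)
```

The last batch is smaller because `DataLoader` keeps incomplete batches unless `drop_last=True` is passed.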

Visualize the built-in dataset

import matplotlib.pyplot as plt

data = next(iter(dataset_loader))
image = data['input']
label = data['output']

plt.figure(figsize=(5, 5))
plt.suptitle(f'{name.upper()} dataset Samples - range: [{image.min()}, {image.max()}]')

for i in range(9):
    plt.subplot(3, 3, i + 1)
    plt.imshow(image[i].permute(1, 2, 0).cpu().numpy())  # channels-first to (height, width, channels)
    plt.title(f'Label: {label[i]}')
    plt.axis('off')

plt.tight_layout()
plt.show()
Figure: FASHION_MNIST dataset samples, range [0.0, 1.0]. A 3 x 3 grid of images labeled 9, 0, 0, 3, 0, 2, 7, 2, 5.
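
`imshow` expects image arrays laid out as (height, width, channels). Assuming the batch samples are stored channels-first, the axes have to be reordered before plotting. A dependency-free sketch of what two different axis orders do, using nested lists in place of tensors (the helper `permute3` is hypothetical):

```python
def permute3(array, order):
    """Reorder the axes of a 3-level nested list, like Tensor.permute on a 3-D tensor."""
    shape = (len(array), len(array[0]), len(array[0][0]))
    new_shape = [shape[axis] for axis in order]

    def source_value(i, j, k):
        # Output index (i, j, k) reads from the input position whose
        # axis order[n] holds the n-th output index.
        index = [0, 0, 0]
        for out_axis, in_axis in enumerate(order):
            index[in_axis] = (i, j, k)[out_axis]
        return array[index[0]][index[1]][index[2]]

    return [[[source_value(i, j, k) for k in range(new_shape[2])]
             for j in range(new_shape[1])]
            for i in range(new_shape[0])]


chw = [[[1, 2, 3],
        [4, 5, 6]]]  # 1 channel, height 2, width 3

hwc = permute3(chw, (1, 2, 0))         # height, width, channels
transposed = permute3(chw, (2, 1, 0))  # width, height, channels

print(hwc)         # rows of the original image are preserved
print(transposed)  # rows and columns are swapped
```

For square images such as the 28 x 28 Fashion-MNIST samples, both orders produce a valid picture, but `(2, 1, 0)` shows it mirrored along the main diagonal.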

Cave dataset

CAVE is a database of multispectral images that were used to emulate the GAP camera. The images cover a wide variety of real-world materials and objects. Download the dataset from http://www.cs.columbia.edu/CAVE/databases/multispectral/ and then extract the files and place them in the ‘data’ folder.
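
The extraction step can be scripted with the standard library. The sketch below assumes the download is a single zip archive. The archive name `cave.zip` and the helper `extract_archive` are hypothetical, and the folder layout that `CustomDataset` expects inside `data` is not specified here, so check it against the library documentation:

```python
import zipfile
from pathlib import Path


def extract_archive(archive_path, destination="data"):
    """Extract a zip archive into `destination` and return the extracted member names."""
    destination = Path(destination)
    destination.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(destination)
        return archive.namelist()


if __name__ == "__main__":
    archive = Path("cave.zip")  # hypothetical file name for the downloaded archive
    if archive.exists():
        names = extract_archive(archive, "data")
        print(f"Extracted {len(names)} entries into 'data'")
    else:
        print(f"{archive} not found; download it first")
```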

dataset = CustomDataset("cave", path)

dataset_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

Visualize cave dataset

def normalize(image):
    return (image - image.min()) / (image.max() - image.min())

data = next(iter(dataset_loader))
rgb_image = data['input']
spec_image = data['output']

[_, _, M, N] = spec_image.shape

plt.figure(figsize=(8, 8))
plt.suptitle(f'CAVE dataset Samples - spectral range: [{spec_image.min():.2f}, {spec_image.max():.2f}] - '
             f'RGB range: [{rgb_image.min():.2f}, {rgb_image.max():.2f}]')

for i in range(3):
    coord1 = [randint(0, M-1), randint(0, N-1)]
    coord2 = [randint(0, M-1), randint(0, N-1)]

    plt.subplot(3, 3, (3 * i) + 1)
    plt.imshow(normalize(rgb_image[i].permute(1, 2, 0).cpu().numpy()))
    plt.title('rgb')
    plt.axis('off')

    plt.subplot(3, 3, (3 * i) + 2)
    plt.imshow(normalize(spec_image[i, [18, 12, 8]].permute(1, 2, 0).cpu().numpy()))
    plt.scatter(coord1[1], coord1[0], s=120, edgecolors='black')
    plt.scatter(coord2[1], coord2[0], s=120, edgecolors='black')
    plt.title('spec bands [18, 12, 8]')
    plt.axis('off')

    plt.subplot(3, 3, (3 * i) + 3)
    plt.plot(normalize(spec_image[i, :, coord1[0], coord1[1]].cpu().numpy()), linewidth=2, label='p1')
    plt.plot(normalize(spec_image[i, :, coord2[0], coord2[1]].cpu().numpy()), linewidth=2, label='p2')
    plt.title('spec signatures')
    plt.xlabel('Wavelength [nm]')
    plt.grid()
    plt.legend()

plt.tight_layout()
plt.show()
Figure: CAVE dataset samples, spectral range [0.00, 0.97], RGB range [0.00, 1.00]. Each of the three rows shows the RGB image, spectral bands [18, 12, 8] with the two sampled points marked, and the spectral signatures of those points.
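
The CAVE images cover 400 nm to 700 nm in 10 nm steps, which gives 31 bands. Assuming the loader keeps all 31 bands in ascending order (not confirmed here), band indices can be mapped to wavelengths as follows; the helper `band_wavelengths` is hypothetical:

```python
def band_wavelengths(num_bands=31, start_nm=400, step_nm=10):
    """Return the centre wavelength in nm for each band index (assumed CAVE layout)."""
    return [start_nm + step_nm * index for index in range(num_bands)]


wavelengths = band_wavelengths()
false_color_bands = [18, 12, 8]  # indices used for the false-colour preview above

for band in false_color_bands:
    print(f"band {band}: {wavelengths[band]} nm")
```

Passing `wavelengths` as the x values of `plt.plot` makes the `Wavelength [nm]` axis label literal; without it the x axis shows band indices.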
