Source code for colibri.data.utils

import os

import numpy as np
from PIL import Image

import torchvision
import torch
# Builtin datasets

BUILTIN_DATASETS = {

    'mnist': torchvision.datasets.MNIST,
    'fashion_mnist': torchvision.datasets.FashionMNIST,
    'cifar10': torchvision.datasets.CIFAR10,
    'cifar100': torchvision.datasets.CIFAR100
}


[docs] def update_builtin_path(name: str, path: str): r""" Update the built-in path by creating a new directory with the given name inside the specified path. Args: name (str): The name of the directory to be created. path (str): The path where the new directory will be created. Returns: str: The path of the newly created directory. """ path = os.path.join(path, name) os.makedirs(path, exist_ok=True) return path
[docs] def load_builtin_dataset(name: str, path: str, **kwargs): r""" Load a built-in dataset. Args: name (str): The name of the dataset. path (str): The path to save the dataset. **kwargs: Additional keyword arguments to pass to the pytorch dataset loader. Returns: dict: A dictionary containing the input and output data of the dataset. Raises: KeyError: If the specified dataset name is not found. """ train = kwargs['train'] if 'train' in kwargs else False download = kwargs['download'] if 'download' in kwargs else True builtin_dataset = BUILTIN_DATASETS[name](root=path, train=train, download=download) dataset = dict(input=builtin_dataset.data, output=builtin_dataset.targets) # transform dataset['input'] = (dataset['input'] / 255.).astype(np.float32) if dataset['input'].ndim != 4: dataset['input'] = dataset['input'].unsqueeze(1) dataset['input'] = np.transpose(dataset['input'], (0, 3, 2, 1)) dataset['input'] = torch.from_numpy(dataset['input']) dataset['output'] = torch.tensor(dataset['output']) return dataset