import torch
from torch.utils.data import Dataset
from torchvision import transforms
from colibri.data.sota_datasets import CaveDataset
from colibri.data.utils import BUILTIN_DATASETS, update_builtin_path, load_builtin_dataset
DATASET_HANDLERS = {
'cave': CaveDataset,
}
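
# A handler registered in DATASET_HANDLERS is expected to expose two methods:
# `get_list_paths()`, returning the list of sample file paths, and
# `load_item(path)`, returning one sample as a dict of arrays/tensors.
# A minimal sketch of registering another handler (the `MyDataset` class below
# is hypothetical, not part of colibri):
#
#     class MyDataset:
#         def __init__(self, path):
#             self.path = path
#         def get_list_paths(self):
#             ...  # list the sample files under self.path
#         def load_item(self, filename):
#             ...  # read one file and return a dict such as {'input': ...}
#
#     DATASET_HANDLERS['my_dataset'] = MyDataset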
class CustomDataset(Dataset):
r"""Custom dataset.
This class allows to load custom datasets and apply transformations to the data.
The datasets that can be currently loaded are:
- 'cifar10'
- 'mnist'
- 'fashion_mnist'
- 'cave'
This class is divided in two parts:
- builtin datasets: datasets that are predefined in the repository.
- custom datasets: datasets that are not predefined in the repository.
The builtin datasets are loaded using the function `load_builtin_dataset` from the module `colibri.data.utils`.
The custom datasets are loaded using the function `get_filenames` from the module `colibri.data.utils`.
The transformations are applied to the data using the `torchvision.transforms` module.
The default transformations are:
- input: `transforms.ToTensor()`
- output: `transforms.ToTensor()`
"""
    def __init__(self, name: str, path: str = "data", transform_dict: dict = None, **kwargs_builtin):
r"""
        Args:
            name (str): Name of the dataset.
                Current options are: ('cifar10', 'cifar100', 'mnist', 'fashion_mnist', 'cave').
            path (str): Path to the directory containing the dataset.
            transform_dict (dict): Dictionary mapping data keys (e.g. 'input') to the transformations to apply.
            kwargs_builtin (dict): Keyword arguments used to load the builtin datasets. Each PyTorch dataset has its own parameters; please refer to the PyTorch documentation.
"""
        self.is_builtin_dataset = False
        if name in BUILTIN_DATASETS:  # builtin datasets
            self.is_builtin_dataset = True
            path = update_builtin_path(name, path)
            self.dataset = load_builtin_dataset(name, path, **kwargs_builtin)
            self.len_dataset = len(self.dataset['input'])
        elif name in DATASET_HANDLERS:  # custom datasets
            self.data_handler = DATASET_HANDLERS[name](path)
            self.dataset_filenames = self.data_handler.get_list_paths()
            self.data_reader = self.data_handler.load_item
            self.len_dataset = len(self.dataset_filenames)
        else:
            raise ValueError(
                f"Unknown dataset '{name}'. Available options: "
                f"{list(BUILTIN_DATASETS) + list(DATASET_HANDLERS)}"
            )
        self.transform_dict = transform_dict if transform_dict is not None else {}
def __len__(self):
r"""
Returns:
int: Length of the dataset.
"""
return self.len_dataset
def __getitem__(self, idx):
r"""
Args:
idx (int): Index of the sample to load.
Returns:
            dict: Dictionary with the sample data, keyed by data field (e.g. 'input').
"""
# load sample
if self.is_builtin_dataset:
data = {key: value[idx] for key, value in self.dataset.items()}
else:
data = self.data_reader(self.dataset_filenames[idx])
        # apply transformations: each entry maps a data key to a callable
        for key, transform in self.transform_dict.items():
            data[key] = transform(data[key])
return data
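

if __name__ == "__main__":
    # Minimal usage sketch, assuming MNIST can be downloaded to ./data and that
    # `load_builtin_dataset` forwards keyword arguments such as `train` and
    # `download` to the underlying torchvision dataset (check
    # colibri.data.utils for the exact set), and that samples are returned as
    # PIL images or NumPy arrays that `transforms.ToTensor()` can consume.
    dataset = CustomDataset(
        name='mnist',
        path='data',
        transform_dict={'input': transforms.ToTensor()},
        train=True,
        download=True,
    )
    print(len(dataset))           # number of samples
    sample = dataset[0]           # dict with (at least) an 'input' entry
    print(sample['input'].shape)  # e.g. torch.Size([1, 28, 28])

    # The dataset plugs into a standard PyTorch DataLoader.
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)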