diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 0fb4edcfab0a349546e5f7704c0783a1fe54b919..2cbc4b848d080de2d4500bd84111efb8905efcfa 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -4,6 +4,7 @@
 
 import collections
 import functools
+import itertools
 import logging
 import multiprocessing
 import sys
@@ -17,12 +18,14 @@ import torchvision.transforms
 import tqdm
 
 from .typing import (
+    ConcatDatabaseSplit,
     DatabaseSplit,
     DataLoader,
     Dataset,
     RawDataLoader,
     Sample,
     Transform,
+    TransformSequence,
 )
 
 logger = logging.getLogger(__name__)
@@ -72,9 +75,9 @@ class _DelayedLoadingDataset(Dataset):
     Parameters
     ----------
 
-    split
-        An iterable containing the raw dataset samples loaded from the database
-        splits.
+    raw_dataset
+        An iterable containing the raw dataset samples representing one of the
+        database split datasets.
 
     loader
         An object instance that can load samples and labels from storage.
@@ -86,11 +89,11 @@
 
     def __init__(
         self,
-        split: typing.Sequence[typing.Any],
+        raw_dataset: typing.Sequence[typing.Any],
         loader: RawDataLoader,
-        transforms: typing.Sequence[Transform] = [],
+        transforms: TransformSequence = [],
     ):
-        self.split = split
+        self.raw_dataset = raw_dataset
        self.loader = loader
        self.transform = torchvision.transforms.Compose(transforms)
 
@@ -105,14 +108,14 @@
    def labels(self) -> list[int]:
        """Returns the integer labels for all samples in the dataset."""
-        return [self.loader.label(k) for k in self.split]
+        return [self.loader.label(k) for k in self.raw_dataset]
 
    def __getitem__(self, key: int) -> Sample:
-        tensor, metadata = self.loader.sample(self.split[key])
+        tensor, metadata = self.loader.sample(self.raw_dataset[key])
        return self.transform(tensor), metadata
 
    def __len__(self):
-        return len(self.split)
+        return len(self.raw_dataset)
 
    def __iter__(self):
        for x in range(len(self)):
            yield self[x]
@@ -131,7 +134,7 @@ def _apply_loader_and_transforms(
    ----------
 
    info
-        The sample information, as loaded from its split dictionary
+        The sample information, as loaded from its raw dataset dictionary
 
    load
        The raw-data loader function to use for loading the sample
@@ -155,7 +158,7 @@
 class _CachedDataset(Dataset):
    """Basically, a list of preloaded samples.
 
-    This dataset will load all samples from the split during construction
+    This dataset will load all samples from the raw dataset during construction
    instead of delaying that to the indexing.  Beyong raw-data-loading,
    ``transforms`` given upon construction contribute to the cached samples.
 
@@ -163,9 +166,9 @@
    Parameters
    ----------
 
-    split
-        An iterable containing the raw dataset samples loaded from the database
-        splits.
+    raw_dataset
+        An iterable containing the raw dataset samples representing one of the
+        database split datasets.
 
    loader
        An object instance that can load samples and labels from storage.
@@ -184,10 +187,10 @@ class _CachedDataset(Dataset):
 
    def __init__(
        self,
-        split: typing.Sequence[typing.Any],
+        raw_dataset: typing.Sequence[typing.Any],
        loader: RawDataLoader,
        parallel: int = -1,
-        transforms: typing.Sequence[Transform] = [],
+        transforms: TransformSequence = [],
    ):
        self.loader = functools.partial(
            _apply_loader_and_transforms,
@@ -197,14 +200,16 @@
 
        if parallel < 0:
            self.data = [
-                self.loader(k) for k in tqdm.tqdm(split, unit="sample")
+                self.loader(k) for k in tqdm.tqdm(raw_dataset, unit="sample")
            ]
        else:
            instances = parallel or multiprocessing.cpu_count()
            logger.info(f"Caching dataset using {instances} processes...")
            with multiprocessing.Pool(instances) as p:
                self.data = list(
-                    tqdm.tqdm(p.imap(self.loader, split), total=len(split))
+                    tqdm.tqdm(
+                        p.imap(self.loader, raw_dataset), total=len(raw_dataset)
+                    )
                )
 
        # Estimates memory occupance
@@ -229,8 +234,42 @@
        return len(self.data)
 
    def __iter__(self):
-        for x in range(len(self)):
-            yield self[x]
+        yield from self.data
+
+
+class _ConcatDataset(Dataset):
+    """A dataset that represents a concatenation of other cached or delayed
+    datasets.
+
+    Parameters
+    ----------
+
+    datasets
+        An iterable over pre-instantiated datasets.
+    """
+
+    def __init__(self, datasets: typing.Sequence[Dataset]):
+        self._datasets = datasets
+        self._indices = [
+            (i, j)  # dataset relative position, sample relative position
+            for i in range(len(datasets))
+            for j in range(len(datasets[i]))
+        ]
+
+    def labels(self) -> list[int]:
+        """Returns the integer labels for all samples in the dataset."""
+        return list(itertools.chain(*[k.labels() for k in self._datasets]))
+
+    def __getitem__(self, key: int) -> Sample:
+        i, j = self._indices[key]
+        return self._datasets[i][j]
+
+    def __len__(self):
+        return sum([len(k) for k in self._datasets])
+
+    def __iter__(self):
+        for dataset in self._datasets:
+            yield from dataset
 
 
 def _make_balanced_random_sampler(
@@ -375,14 +414,15 @@
    )
 
 
-class CachingDataModule(lightning.LightningDataModule):
-    """A conveninent data module with CSV or JSON protocol loading, mini-
-    batching, parallelisation and caching, all in one.
+class ConcatDataModule(lightning.LightningDataModule):
+    """A convenient data module with dictionary split loading, mini-batching,
+    parallelisation and caching, all in one.
 
-    Instances of this class load data-split (a.k.a. protocol) definitions for a
-    database, and can load the data from the disk.  An optional caching
-    mechanism stores the data at associated CPU memory, which can improve data
-    serving while training and evaluating models.
+    Instances of this class can load and concatenate an arbitrary number of
+    data-split (a.k.a. protocol) definitions for (possibly disjoint) databases,
+    and can manage raw data-loading from disk.  An optional caching mechanism
+    stores the data in associated CPU memory, which can improve data serving
+    while training and evaluating models.
 
    This datamodule defines basic operations to handle data loading and
    mini-batch handling within this package's framework.  It can return
@@ -390,31 +430,32 @@
    dataloaders for training, validation, prediction and testing conditions.
    Parallelisation is handled by a simple input flag.
 
-    Users must implement the basic :py:meth:`setup` function, which is
-    parameterised by a single string enumeration containing: ``fit``,
-    ``validate``, ``test``, or ``predict``.
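+
+    As a minimal sketch (the database names, sample lists and loader instances
+    below are illustrative assumptions), a ``splits`` input concatenating two
+    databases for training could look like:
+
+    .. code-block:: python
+
+       # hypothetical sample sequences and RawDataLoader instances:
+       splits = {
+           "train": [
+               (db1_train_samples, db1_loader),
+               (db2_train_samples, db2_loader),
+           ],
+           "validation": [(db1_validation_samples, db1_loader)],
+           "test": [(db1_test_samples, db1_loader)],
+       }
+       datamodule = ConcatDataModule(splits)
+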
-
 
    Parameters
    ----------
 
-    database_split
-        A dictionary that contains string keys representing subset names, and
-        values that are iterables over sample representations (potentially on
-        disk).  These objects are passed to the ``sample_loader`` for loading
-        the sample data (and metadata) in memory.  The objects represented may
-        be of any format (e.g. list, dictionary, etc), for as long as the
-        ``sample_loader`` can properly handle it.  To check the split and the
-        loader function works correctly, you may use
-        :py:func:`..dataset.check_database_split_loading`.  As is, this class
-        expects at least one entry called ``train`` to exist in the input
-        dictionary.  Optional entries are ``validation``, and ``test``.  Entries
-        named ``monitor-...`` will be considered extra subsets that do not
-        influence any early stop criteria during training, and are just
-        monitored beyond the ``validation`` dataset.
+    splits
+        A dictionary that contains string keys representing dataset names, and
+        values that are iterables over 2-tuples, each containing an iterable
+        over arbitrary, user-configurable sample representations (potentially
+        on disk or permanent storage), and a :py:class:`RawDataLoader` (or
+        "sample") loader object, which concretely implements a mechanism to
+        load such samples into memory, from permanent storage.
 
-    loader
-        An object instance that can load samples and labels from storage.
+        Sample representations on permanent storage may be of any iterable
+        format (e.g. list, dictionary, etc.), for as long as the assigned
+        :py:class:`RawDataLoader` can properly handle it.
+
+        .. tip::
+
+           To check that the splits and loader functions work correctly, you
+           may use :py:func:`..dataset.check_database_split_loading`.
+
+        This class expects at least one entry called ``train`` to exist in the
+        input dictionary.  Optional entries are ``validation``, and ``test``.
+        Entries named ``monitor-...`` will be considered extra datasets that do
+        not influence any early stop criteria during training, and are just
+        monitored beyond the ``validation`` dataset.
 
    cache_samples
        If set, then issue raw data loading during ``prepare_data()``, and
@@ -486,8 +527,7 @@
 
    def __init__(
        self,
-        database_split: DatabaseSplit,
-        raw_data_loader: RawDataLoader,
+        splits: ConcatDatabaseSplit,
        cache_samples: bool = False,
        balance_sampler_by_class: bool = False,
        batch_size: int = 1,
@@ -499,8 +539,7 @@
 
        self.set_chunk_size(batch_size, batch_chunk_count)
 
-        self.database_split = database_split
-        self.raw_data_loader = raw_data_loader
+        self.splits = splits
        self.cache_samples = cache_samples
        self._train_sampler = None
        self.balance_sampler_by_class = balance_sampler_by_class
@@ -581,6 +620,14 @@
        If set, then modifies the random sampler used during training and
        validation to balance sample picking probability, making sample
        across classes **and** datasets equitable.
+
+        .. warning::
+
+           This method does **NOT** balance the sampler per dataset, in case
+           multiple datasets compose the same training set.  It only balances
+           samples according to their ground-truth (labels).  If you'd like to
+           have samples balanced per dataset, then implement your own data
+           module inheriting from this one.
        """
        return self._train_sampler is not None
 
@@ -661,32 +708,45 @@
                f"Not re-instantiating it."
            )
            return
+
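+        # assemble one dataset per (split, loader) pair composing this split;
+        # a single-pair split is stored as-is, while multiple pairs are
+        # merged through _ConcatDataset below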
+        datasets: list[_CachedDataset | _DelayedLoadingDataset] = []
 
        if self.cache_samples:
            logger.info(
                f"Loading dataset:`{name}` into memory (caching)."
                f" Trade-off: CPU RAM: more | Disk: less"
            )
-            self._datasets[name] = _CachedDataset(
-                self.database_split[name],
-                self.raw_data_loader,
-                self.parallel,
-                self.model_transforms,
-            )
+            for split, loader in self.splits[name]:
+                datasets.append(
+                    _CachedDataset(
+                        split,
+                        loader,
+                        self.parallel,
+                        self.model_transforms,
+                    )
+                )
        else:
            logger.info(
                f"Loading dataset:`{name}` without caching."
                f" Trade-off: CPU RAM: less | Disk: more"
            )
-            self._datasets[name] = _DelayedLoadingDataset(
-                self.database_split[name],
-                self.raw_data_loader,
-                self.model_transforms,
-            )
+            for split, loader in self.splits[name]:
+                datasets.append(
+                    _DelayedLoadingDataset(
+                        split,
+                        loader,
+                        self.model_transforms,
+                    )
+                )
+
+        if len(datasets) == 1:
+            self._datasets[name] = datasets[0]
+        else:
+            self._datasets[name] = _ConcatDataset(datasets)
 
    def _val_dataset_keys(self) -> list[str]:
        """Returns list of validation dataset names."""
        return ["validation"] + [
-            k for k in self.database_split.keys() if k.startswith("monitor-")
+            k for k in self.splits.keys() if k.startswith("monitor-")
        ]
 
    def setup(self, stage: str) -> None:
@@ -727,7 +787,7 @@
            self._setup_dataset("test")
 
        elif stage == "predict":
-            for k in self.database_split:
+            for k in self.splits:
                self._setup_dataset(k)
 
    def teardown(self, stage: str) -> None:
@@ -826,3 +886,53 @@
            )
            for k in self._datasets
        }
+
+
+class CachingDataModule(ConcatDataModule):
+    """A simplified version of our data module for a single split.
+
+    Apart from construction, the behaviour of this data module is very similar
+    to its more general counterpart, :py:class:`ConcatDataModule`, serving
+    training, validation and test sets.
+
+    Parameters
+    ----------
+
+    database_split
+        A dictionary that contains string keys representing dataset names, and
+        values that are iterables over sample representations (potentially on
+        disk).  These objects are passed to a unique :py:class:`RawDataLoader`
+        for loading the :py:class:`Sample` data (and metadata) in memory.  It
+        therefore assumes the whole split is homogeneous and can be loaded in
+        the same way.
+
+        .. tip::
+
+           To check that the split and the loader function work correctly, you
+           may use :py:func:`..dataset.check_database_split_loading`.
+
+        This class expects at least one entry called ``train`` to exist in the
+        input dictionary.  Optional entries are ``validation``, and ``test``.
+        Entries named ``monitor-...`` will be considered extra datasets that do
+        not influence any early stop criteria during training, and are just
+        monitored beyond the ``validation`` dataset.
+
+    raw_data_loader
+        An object instance that can load samples and labels from storage.
+
+    **kwargs
+        List of named parameters matching those of
+        :py:class:`ConcatDataModule`, other than ``splits``.
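+
+    As a minimal usage sketch (the sample lists and the loader class below are
+    illustrative assumptions, not part of this package):
+
+    .. code-block:: python
+
+       datamodule = CachingDataModule(
+           database_split={
+               "train": train_samples,
+               "validation": validation_samples,
+               "test": test_samples,
+           },
+           raw_data_loader=MyRawDataLoader(),
+           batch_size=16,
+       )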
+ """ + + def __init__( + self, + database_split: DatabaseSplit, + raw_data_loader: RawDataLoader, + **kwargs, + ): + splits = {k: [(v, raw_data_loader)] for k, v in database_split.items()} + super().__init__( + splits=splits, + **kwargs, + ) diff --git a/src/ptbench/data/typing.py b/src/ptbench/data/typing.py index e8dc56a02c55687ecebefd84d2ab8c27e23cafc6..c1bcf73e6ff327668bbd98005ad4c73f5254976d 100644 --- a/src/ptbench/data/typing.py +++ b/src/ptbench/data/typing.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later - """Defines most common types used in code.""" import collections.abc @@ -51,8 +50,23 @@ TransformSequence = typing.Sequence[Transform] DatabaseSplit = collections.abc.Mapping[str, typing.Sequence[typing.Any]] """The definition of a database split. -A database split maps subset names to sequences of objects that, through -RawDataLoader's eventually become Samples in the processing pipeline. +A database split maps dataset (subset) names to sequences of objects +that, through :py:class:`RawDataLoader`s, eventually become +:py:class:`Sample`s in the processing pipeline. +""" + +ConcatDatabaseSplit = collections.abc.Mapping[ + str, + typing.Sequence[tuple[typing.Sequence[typing.Any], RawDataLoader]], +] +"""The definition of a complex database split composed of several other splits. + +A database split maps dataset (subset) names to sequences of objects +that, through :py:class:`RawDataLoader`s, eventually become +:py:class:`Sample`s in the processing pipeline. Objects of this subtype +allow the construction of complex splits composed of cannibalized parts +of other splits. Each split may be assigned a different +:py:class:`RawDataLoader`. """