# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later import collections import logging import multiprocessing import sys import typing import lightning import torch import torch.backends import torch.utils.data import torchvision.transforms import tqdm from .typing import ( DatabaseSplit, DataLoader, Dataset, RawDataLoader, Sample, Transform, ) logger = logging.getLogger(__name__) def _setup_dataloader_multiproc_parameters( parallel: int, ) -> dict[str, typing.Any]: """Returns a dictionary containing pytorch arguments to be used in data loaders. It sets the parameter ``num_workers`` to match the expected pytorch representation. For macOS machines, it also sets the ``multiprocessing_context`` to use ``spawn`` instead of the default. The mapping between the command-line interface ``parallel`` setting works like this: .. list-table:: Relationship between ``parallel`` and DataLoader parameterisation :widths: 15 15 70 :header-rows: 1 * - CLI ``parallel`` - :py:class:`torch.utils.data.DataLoader` ``kwargs`` - Comments * - ``<0`` - 0 - Disables multiprocessing entirely, executes everything within the same processing context * - ``0`` - :py:func:`multiprocessing.cpu_count` - Runs mini-batch data loading on as many external processes as CPUs available in the current machine * - ``>=1`` - ``parallel`` - Runs mini-batch data loading on as many external processes as set on ``parallel`` """ retval: dict[str, typing.Any] = dict() if parallel < 0: retval["num_workers"] = 0 else: retval["num_workers"] = parallel or multiprocessing.cpu_count() if retval["num_workers"] > 0 and sys.platform == "darwin": retval["multiprocessing_context"] = multiprocessing.get_context("spawn") return retval class _DelayedLoadingDataset(Dataset): """A list that loads its samples on demand. This list mimics a pytorch Dataset, except raw data loading is done on-the-fly, as the samples are requested through the bracket operator. Parameters ---------- split An iterable containing the raw dataset samples loaded from the database splits. loader An object instance that can load samples and labels from storage. transforms A set of transforms that should be applied on-the-fly for this dataset, to fit the output of the raw-data-loader to the model of interest. """ def __init__( self, split: typing.Sequence[typing.Any], loader: RawDataLoader, transforms: typing.Sequence[Transform] = [], ): self.split = split self.loader = loader self.transform = torchvision.transforms.Compose(transforms) def labels(self) -> list[int]: """Returns the integer labels for all samples in the dataset.""" return [self.loader.label(k) for k in self.split] def __getitem__(self, key: int) -> Sample: tensor, metadata = self.loader.sample(self.split[key]) return self.transform(tensor), metadata def __len__(self): return len(self.split) def __iter__(self): for x in range(len(self)): yield self[x] class _CachedDataset(Dataset): """Basically, a list of preloaded samples. This dataset will load all samples from the split during construction instead of delaying that to the indexing. Beyong raw-data-loading, ``transforms`` given upon construction contribute to the cached samples. Parameters ---------- split An iterable containing the raw dataset samples loaded from the database splits. loader An object instance that can load samples and labels from storage. parallel Use multiprocessing for data loading: if set to -1 (default), disables multiprocessing data loading. Set to 0 to enable as many data loading instances as processing cores as available in the system. Set to >= 1 to enable that many multiprocessing instances for data loading. transforms A set of transforms that should be applied to the cached samples for this dataset, to fit the output of the raw-data-loader to the model of interest. """ def __init__( self, split: typing.Sequence[typing.Any], loader: RawDataLoader, parallel: int = -1, transforms: typing.Sequence[Transform] = [], ): self.transform = torchvision.transforms.Compose(transforms) if parallel < 0: self.data = [ loader.sample(k) for k in tqdm.tqdm(split, unit="sample") ] else: instances = parallel or multiprocessing.cpu_count() logger.info(f"Caching dataset using {instances} processes...") with multiprocessing.Pool(instances) as p: self.data = list( tqdm.tqdm(p.imap(loader.sample, split), total=len(split)) ) def labels(self) -> list[int]: """Returns the integer labels for all samples in the dataset.""" return [k[1]["label"] for k in self.data] def __getitem__(self, key: int) -> Sample: tensor, metadata = self.data[key] return self.transform(tensor), metadata def __len__(self): return len(self.data) def __iter__(self): for x in range(len(self)): yield self[x] def _make_balanced_random_sampler( dataset: Dataset, target: str = "label", ) -> torch.utils.data.WeightedRandomSampler: """Generates a pytorch sampler that samples according to class probabilities. This function takes as input a torch Dataset, and computes the weights to balance each class in the dataset, and the datasets themselves if one passes a :py:class:`torch.utils.data.ConcatDataset`. In this implementation, we balance **both** class and dataset-origin probabilities, what you expect for a truly *equitable* random sampler. Take this example for illustration: * Dataset 1: N = 10 samples, 9 samples with target=0, 1 sample with target=1 * Dataset 2: N = 6 samples, 3 samples with target=0, 3 samples with target=1 So: | Dataset | Target | Samples | Weight | Normalised weight | +---------+--------+---------+--------+-------------------+ | 1 | 0 | 9 | 1/9 | 1/36 | | 1 | 1 | 1 | 1/1 | 1/4 | | 2 | 0 | 3 | 1/3 | 1/12 | | 2 | 1 | 3 | 1/3 | 1/12 | Legend: * Weight: the weights computed by this method * Normalised weight: the weight per sample used by the random sampler, after normalising the weights by the sum of all weights in the concatenated dataset, such that the sum of all normalized weights times the number of samples is 1. The properties of this algorithm are as follows: 1. The probability of picking a sample from any target is the same (0.5 in this case). To verify this, notice that the probability of picking a sample with ``target=0`` is :math:`1/4 x 1 + 1/12 x 3 = 0.5`. 2. The probabiility of picking a sample with ``target=0`` from Dataset 2 is 3 times higher than those from Dataset 1. As there are 3 times less samples in Dataset 2 with ``target=0``, this makes choosing samples from Dataset 1 proportionally less likely. 3. The probabiility of picking a sample with ``target=1`` from Dataset 2 is 3 times lower than those from Dataset 1. As there are 3 times less samples in Dataset 1 with ``target=1``, this makes choosing samples from Dataset 2 proportionally less likely. This function assumes targets are stored on a dictionary entry named ``target`` inside the metadata information for the :py:type:``Sample``, and that its value is integer. We then instantiate a pytorch sampler using the inverse probabilities (the more samples of a class, the less likely it becomes to be sampled. Parameters ---------- dataset An instance of torch Dataset. :py:class:`torch.utils.data.ConcatDataset` are supported. target The name of a metadata key pointing to an integer property that allows balancing the dataset. Returns ------- sampler A sampler, to be used in a dataloader equipped with the same dataset used to calculate the relative sample weights. Raises ------ RuntimeError If requested to balance a dataset (single, not-concatenated) without an existing target. """ def _calculate_weights(targets: list[int]) -> list[float]: counts = collections.Counter(targets) weights = {k: 1.0 / v for k, v in counts.items()} return [weights[k] for k in targets] if isinstance(dataset, torch.utils.data.ConcatDataset): # There are two possible cases: targets/no-targets metadata_example = dataset.datasets[0][0][1] if target in metadata_example and isinstance( metadata_example[target], int ): # there are integer targets, let's balance with those logger.info( f"Balancing sample selection probabilities **and** " f"concatenated-datasets using metadata targets `{target}`" ) targets = [ k for ds in dataset.datasets for k in typing.cast(Dataset, ds).labels() ] weights = _calculate_weights(targets) else: logger.warning( f"Balancing samples **and** concatenated-datasets " f"WITHOUT metadata targets (`{target}` not available)" ) weights = [ k for ds in dataset.datasets for k in len(typing.cast(typing.Sized, ds)) * [1.0 / len(typing.cast(typing.Sized, ds))] ] pass else: metadata_example = dataset[0][1] if target in metadata_example and isinstance( metadata_example[target], int ): logger.info( f"Balancing samples from dataset using metadata " f"targets `{target}`" ) weights = _calculate_weights(dataset.labels()) else: raise RuntimeError( f"Cannot balance samples without metadata targets `{target}`" ) return torch.utils.data.WeightedRandomSampler( weights, len(weights), replacement=True ) class CachingDataModule(lightning.LightningDataModule): """A conveninent data module with CSV or JSON protocol loading, mini- batching, parallelisation and caching, all in one. Instances of this class load data-split (a.k.a. protocol) definitions for a database, and can load the data from the disk. An optional caching mechanism stores the data at associated CPU memory, which can improve data serving while training and evaluating models. This datamodule defines basic operations to handle data loading and mini-batch handling within this package's framework. It can return :py:class:`torch.utils.data.DataLoader` objects for training, validation, prediction and testing conditions. Parallelisation is handled by a simple input flag. Users must implement the basic :py:meth:`setup` function, which is parameterised by a single string enumeration containing: ``fit``, ``validate``, ``test``, or ``predict``. Parameters ---------- database_split A dictionary that contains string keys representing subset names, and values that are iterables over sample representations (potentially on disk). These objects are passed to the ``sample_loader`` for loading the sample data (and metadata) in memory. The objects represented may be of any format (e.g. list, dictionary, etc), for as long as the ``sample_loader`` can properly handle it. To check the split and the loader function works correctly, you may use :py:func:`..dataset.check_database_split_loading`. As is, this class expects at least one entry called ``train`` to exist in the input dictionary. Optional entries are ``validation``, and ``test``. Entries named ``monitor-...`` will be considered extra subsets that do not influence any early stop criteria during training, and are just monitored beyond the ``validation`` dataset. loader An object instance that can load samples and labels from storage. cache_samples If set, then issue raw data loading during ``prepare_data()``, and serves samples from CPU memory. Otherwise, loads samples from disk on demand. Running from CPU memory will offer increased speeds in exchange for CPU memory. Sufficient CPU memory must be available before you set this attribute to ``True``. It is typicall useful for relatively small datasets. balance_sampler_by_class If set, then modifies the random sampler used during training and validation to balance sample picking probability, making sample across classes **and** datasets equitable. model_transforms A list of transforms (torch modules) that will be applied after raw-data-loading, and just before data is fed into the model or eventual data-augmentation transformations for all data loaders produced by this data module. This part of the pipeline receives data as output by the raw-data-loader, or model-related transforms (e.g. resize adaptions), if any is specified. batch_size Number of samples in every **training** batch (this parameter affects memory requirements for the network). If the number of samples in the batch is larger than the total number of samples available for training, this value is truncated. If this number is smaller, then batches of the specified size are created and fed to the network until there are no more new samples to feed (epoch is finished). If the total number of training samples is not a multiple of the batch-size, the last batch will be smaller than the first, unless ``drop_incomplete_batch`` is set to ``true``, in which case this batch is not used. batch_chunk_count Number of chunks in every batch (this parameter affects memory requirements for the network). The number of samples loaded for every iteration will be ``batch_size/batch_chunk_count``. ``batch_size`` needs to be divisible by ``batch_chunk_count``, otherwise an error will be raised. This parameter is used to reduce number of samples loaded in each iteration, in order to reduce the memory usage in exchange for processing time (more iterations). This is specially interesting whe one is running with GPUs with limited RAM. The default of 1 forces the whole batch to be processed at once. Otherwise the batch is broken into batch-chunk-count pieces, and gradients are accumulated to complete each batch. drop_incomplete_batch If set, then may drop the last batch in an epoch, in case it is incomplete. If you set this option, you should also consider increasing the total number of epochs of training, as the total number of training steps may be reduced. parallel Use multiprocessing for data loading: if set to -1 (default), disables multiprocessing data loading. Set to 0 to enable as many data loading instances as processing cores as available in the system. Set to >= 1 to enable that many multiprocessing instances for data loading. """ DatasetDictionary = dict[str, Dataset] def __init__( self, database_split: DatabaseSplit, raw_data_loader: RawDataLoader, cache_samples: bool = False, balance_sampler_by_class: bool = False, model_transforms: list[Transform] = [], batch_size: int = 1, batch_chunk_count: int = 1, drop_incomplete_batch: bool = False, parallel: int = -1, ): super().__init__() self.set_chunk_size(batch_size, batch_chunk_count) self.database_split = database_split self.raw_data_loader = raw_data_loader self.cache_samples = cache_samples self._train_sampler = None self.balance_sampler_by_class = balance_sampler_by_class self.model_transforms = model_transforms self.drop_incomplete_batch = drop_incomplete_batch self.parallel = parallel # immutable, otherwise would need to call self.pin_memory = ( torch.cuda.is_available() or torch.backends.mps.is_available() ) # should only be true if GPU available and using it # datasets that have been setup() for the current stage self._datasets: CachingDataModule.DatasetDictionary = {} @property def parallel(self) -> int: """Whether to use multiprocessing for data loading. Use multiprocessing for data loading: if set to -1 (default), disables multiprocessing data loading. Set to 0 to enable as many data loading instances as processing cores as available in the system. Set to >= 1 to enable that many multiprocessing instances for data loading. """ return self._parallel @parallel.setter def parallel(self, value: int) -> None: self._parallel = value self._dataloader_multiproc = _setup_dataloader_multiproc_parameters( value ) # datasets that have been setup() for the current stage self._datasets = {} @property def balance_sampler_by_class(self): """Whether to balance samples across labels/datasets. If set, then modifies the random sampler used during training and validation to balance sample picking probability, making sample across classes **and** datasets equitable. """ return self._train_sampler is not None @balance_sampler_by_class.setter def balance_sampler_by_class(self, value: bool): if value: if "train" not in self._datasets: self._setup_dataset("train") self._train_sampler = _make_balanced_random_sampler( self._datasets["train"] ) else: self._train_sampler = None def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None: """Coherently sets the batch-chunk-size after validation. Parameters ---------- batch_size Number of samples in every **training** batch (this parameter affects memory requirements for the network). If the number of samples in the batch is larger than the total number of samples available for training, this value is truncated. If this number is smaller, then batches of the specified size are created and fed to the network until there are no more new samples to feed (epoch is finished). If the total number of training samples is not a multiple of the batch-size, the last batch will be smaller than the first, unless ``drop_incomplete_batch`` is set to ``true``, in which case this batch is not used. batch_chunk_count Number of chunks in every batch (this parameter affects memory requirements for the network). The number of samples loaded for every iteration will be ``batch_size/batch_chunk_count``. ``batch_size`` needs to be divisible by ``batch_chunk_count``, otherwise an error will be raised. This parameter is used to reduce number of samples loaded in each iteration, in order to reduce the memory usage in exchange for processing time (more iterations). This is specially interesting whe one is running with GPUs with limited RAM. The default of 1 forces the whole batch to be processed at once. Otherwise the batch is broken into batch-chunk-count pieces, and gradients are accumulated to complete each batch. """ # validation if batch_size % batch_chunk_count != 0: raise RuntimeError( f"batch_size ({batch_size}) must be divisible by " f"batch_chunk_size ({batch_chunk_count})." ) self._batch_size = batch_size self._batch_chunk_count = batch_chunk_count self._chunk_size = self._batch_size // self._batch_chunk_count def _setup_dataset(self, name: str) -> None: """Sets-up a single dataset from the input data split. Parameters ---------- name Name of the dataset to setup. """ if name in self._datasets: logger.info( f"Dataset `{name}` is already setup. " f"Not re-instantiating it." ) return if self.cache_samples: logger.info( f"Loading dataset:`{name}` into memory (caching)." f" Trade-off: CPU RAM: more | Disk: less" ) self._datasets[name] = _CachedDataset( self.database_split[name], self.raw_data_loader, self.parallel, self.model_transforms, ) else: logger.info( f"Loading dataset:`{name}` without caching." f" Trade-off: CPU RAM: less | Disk: more" ) self._datasets[name] = _DelayedLoadingDataset( self.database_split[name], self.raw_data_loader, self.model_transforms, ) def _val_dataset_keys(self) -> list[str]: """Returns list of validation dataset names.""" return ["validation"] + [ k for k in self.database_split.keys() if k.startswith("monitor-") ] def setup(self, stage: str) -> None: """Sets up datasets for different tasks on the pipeline. This method should setup (load, pre-process, etc) all datasets required for a particular ``stage`` (fit, validate, test, predict), and keep them ready to be used on one of the `_dataloader()` functions that are pertinent for such stage. If you have set ``cache_samples``, samples are loaded at this stage and cached in memory. Parameters ---------- stage Name of the stage to which the setup is applicable. Can be one of ``fit``, ``validate``, ``test`` or ``predict``. Each stage typically uses the following data loaders: * ``fit``: uses both train and validation datasets * ``validate``: uses only the validation dataset * ``test``: uses only the test dataset * ``predict``: uses only the test dataset """ if stage == "fit": for k in ["train"] + self._val_dataset_keys(): self._setup_dataset(k) elif stage == "validate": for k in self._val_dataset_keys(): self._setup_dataset(k) elif stage == "test": self._setup_dataset("test") elif stage == "predict": self._setup_dataset("test") def teardown(self, stage: str) -> None: """Unset-up datasets for different tasks on the pipeline. This method unsets (unload, remove from memory, etc) all datasets required for a particular ``stage`` (fit, validate, test, predict). If you have set ``cache_samples``, samples are loaded, this may effectivley release all the associated memory. Parameters ---------- stage Name of the stage to which the teardown is applicable. Can be one of ``fit``, ``validate``, ``test`` or ``predict``. Each stage typically uses the following data loaders: * ``fit``: uses both train and validation datasets * ``validate``: uses only the validation dataset * ``test``: uses only the test dataset * ``predict``: uses only the test dataset """ self._datasets = {} def train_dataloader(self) -> DataLoader: """Returns the train data loader.""" return torch.utils.data.DataLoader( self._datasets["train"], shuffle=(self._train_sampler is None), batch_size=self._chunk_size, drop_last=self.drop_incomplete_batch, pin_memory=self.pin_memory, sampler=self._train_sampler, **self._dataloader_multiproc, ) def unshuffled_train_dataloader(self) -> DataLoader: """Returns the train data loader without shuffling.""" return torch.utils.data.DataLoader( self._datasets["train"], shuffle=False, batch_size=self._chunk_size, drop_last=False, **self._dataloader_multiproc, ) def val_dataloader(self) -> dict[str, DataLoader]: """Returns the validation data loader(s)""" validation_loader_opts = { "batch_size": self._chunk_size, "shuffle": False, "drop_last": self.drop_incomplete_batch, "pin_memory": self.pin_memory, } validation_loader_opts.update(self._dataloader_multiproc) return { k: torch.utils.data.DataLoader( self._datasets[k], **validation_loader_opts ) for k in self._val_dataset_keys() } def test_dataloader(self) -> dict[str, DataLoader]: """Returns the test data loader(s)""" return dict( test=torch.utils.data.DataLoader( self._datasets["test"], batch_size=self._chunk_size, shuffle=False, drop_last=self.drop_incomplete_batch, pin_memory=self.pin_memory, **self._dataloader_multiproc, ) ) def predict_dataloader(self) -> dict[str, DataLoader]: """Returns the prediction data loader(s)""" return self.test_dataloader()