# Authored by André Anjos
# datamodule.py (26.29 KiB)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
# SPDX-License-Identifier: GPL-3.0-or-later
import collections
import logging
import multiprocessing
import sys
import typing

import lightning
import torch
import torch.backends
import torch.utils.data
import torchvision.transforms
import tqdm

from .typing import (
    DatabaseSplit,
    DataLoader,
    Dataset,
    RawDataLoader,
    Sample,
    Transform,
)
logger = logging.getLogger(__name__)
def _setup_dataloader_multiproc_parameters(
    parallel: int,
) -> dict[str, typing.Any]:
    """Return the multiprocessing keyword arguments for a pytorch DataLoader.

    It sets the parameter ``num_workers`` to match the expected pytorch
    representation.  For macOS machines, it also sets the
    ``multiprocessing_context`` to use ``spawn`` instead of the default.

    The mapping between the command-line interface ``parallel`` setting works
    like this:

    .. list-table:: Relationship between ``parallel`` and DataLoader parameterisation
       :widths: 15 15 70
       :header-rows: 1

       * - CLI ``parallel``
         - :py:class:`torch.utils.data.DataLoader` ``kwargs``
         - Comments
       * - ``<0``
         - 0
         - Disables multiprocessing entirely, executes everything within the
           same processing context
       * - ``0``
         - :py:func:`multiprocessing.cpu_count`
         - Runs mini-batch data loading on as many external processes as CPUs
           available in the current machine
       * - ``>=1``
         - ``parallel``
         - Runs mini-batch data loading on as many external processes as set on
           ``parallel``

    Parameters
    ----------
    parallel
        Multiprocessing setting, interpreted as in the table above.

    Returns
    -------
    dict[str, typing.Any]
        A dictionary with the key ``num_workers`` and, on macOS with at least
        one worker, the key ``multiprocessing_context``.
    """
    retval: dict[str, typing.Any] = dict()

    if parallel < 0:
        retval["num_workers"] = 0
    else:
        # ``parallel == 0`` is falsy, so it falls through to the CPU count
        retval["num_workers"] = parallel or multiprocessing.cpu_count()

    if retval["num_workers"] > 0 and sys.platform == "darwin":
        retval["multiprocessing_context"] = multiprocessing.get_context(
            "spawn"
        )

    return retval
class _DelayedLoadingDataset(Dataset):
    """A list that loads its samples on demand.

    This list mimics a pytorch Dataset, except raw data loading is done
    on-the-fly, as the samples are requested through the bracket operator.

    Parameters
    ----------
    split
        An iterable containing the raw dataset samples loaded from the
        database splits.
    loader
        An object instance that can load samples and labels from storage.
    transforms
        A set of transforms that should be applied on-the-fly for this dataset,
        to fit the output of the raw-data-loader to the model of interest.
    """

    def __init__(
        self,
        split: typing.Sequence[typing.Any],
        loader: RawDataLoader,
        transforms: typing.Sequence[Transform] = (),
    ):
        self.split = split
        self.loader = loader
        # Build the pipeline from a fresh list so it never aliases the
        # caller's sequence (or a shared default value)
        self.transform = torchvision.transforms.Compose(list(transforms))

    def labels(self) -> list[int]:
        """Returns the integer labels for all samples in the dataset."""
        return [self.loader.label(k) for k in self.split]

    def __getitem__(self, key: int) -> Sample:
        # Raw data is loaded at access time; nothing is kept in memory
        tensor, metadata = self.loader.sample(self.split[key])
        return self.transform(tensor), metadata

    def __len__(self):
        return len(self.split)

    def __iter__(self):
        # Go through ``__getitem__`` so iteration yields loaded, transformed
        # samples rather than raw split entries
        for x in range(len(self)):
            yield self[x]
class _CachedDataset(Dataset):
    """Basically, a list of preloaded samples.

    This dataset loads all samples from the split during construction instead
    of delaying that to the indexing.  The cache holds the output of the
    raw-data-loader; ``transforms`` are applied every time a sample is
    accessed.

    Parameters
    ----------
    split
        An iterable containing the raw dataset samples loaded from the
        database splits.
    loader
        An object instance that can load samples and labels from storage.
    parallel
        Use multiprocessing for data loading: if set to -1 (default), disables
        multiprocessing data loading.  Set to 0 to enable as many data loading
        instances as processing cores as available in the system.  Set to >= 1
        to enable that many multiprocessing instances for data loading.
    transforms
        A set of transforms that should be applied to the cached samples for
        this dataset, to fit the output of the raw-data-loader to the model of
        interest.
    """

    def __init__(
        self,
        split: typing.Sequence[typing.Any],
        loader: RawDataLoader,
        parallel: int = -1,
        transforms: typing.Sequence[Transform] = (),
    ):
        # Build the pipeline from a fresh list so it never aliases the
        # caller's sequence (or a shared default value)
        self.transform = torchvision.transforms.Compose(list(transforms))

        if parallel < 0:
            # Sequential loading within the current process
            self.data = [
                loader.sample(k) for k in tqdm.tqdm(split, unit="sample")
            ]
        else:
            # ``parallel == 0`` is falsy, so it selects the CPU count
            instances = parallel or multiprocessing.cpu_count()
            logger.info(f"Caching dataset using {instances} processes...")
            with multiprocessing.Pool(instances) as p:
                self.data = list(
                    tqdm.tqdm(p.imap(loader.sample, split), total=len(split))
                )

    def labels(self) -> list[int]:
        """Returns the integer labels for all samples in the dataset."""
        # Each cached entry is a ``(tensor, metadata)`` pair
        return [k[1]["label"] for k in self.data]

    def __getitem__(self, key: int) -> Sample:
        tensor, metadata = self.data[key]
        return self.transform(tensor), metadata

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        # Go through ``__getitem__`` so iteration yields transformed samples
        for x in range(len(self)):
            yield self[x]
def _make_balanced_random_sampler(
    dataset: Dataset,
    target: str = "label",
) -> torch.utils.data.WeightedRandomSampler:
    """Generate a pytorch sampler that samples according to class prevalence.

    This function takes as input a torch Dataset and computes one weight per
    sample, using inverse frequencies (the more samples of a class, the less
    likely each of them becomes to be sampled).

    Three situations are handled:

    1. A :py:class:`torch.utils.data.ConcatDataset` whose samples carry an
       integer ``target`` entry in their metadata: the labels of all
       concatenated datasets are pooled, and each sample is weighted by
       ``1 / count(label)``.  For example, with 12 pooled samples of
       ``target=0`` and 4 of ``target=1``, the weights are 1/12 and 1/4, so
       both targets are picked with the same probability.
    2. A :py:class:`torch.utils.data.ConcatDataset` without such an entry:
       each sample is weighted by ``1 / len(its dataset)``, making every
       concatenated dataset equally likely.
    3. A single dataset with an integer ``target`` entry: each sample is
       weighted by ``1 / count(label)``.

    The availability of targets is probed on the metadata of the first sample
    only.

    NOTE(review): earlier documentation of this function described balancing
    class *and* dataset-origin probabilities at the same time (weights
    computed per concatenated dataset).  The pooled computation in case 1
    balances classes only -- confirm which behaviour is intended.

    Parameters
    ----------
    dataset
        An instance of torch Dataset.
        :py:class:`torch.utils.data.ConcatDataset` are supported.
    target
        The name of a metadata key pointing to an integer property that allows
        balancing the dataset.

    Returns
    -------
    torch.utils.data.WeightedRandomSampler
        A sampler, to be used in a dataloader equipped with the same dataset
        used to calculate the relative sample weights.

    Raises
    ------
    RuntimeError
        If requested to balance a dataset (single, not-concatenated) without an
        existing target.
    """

    def _calculate_weights(targets: list[int]) -> list[float]:
        # One weight per sample: the inverse of its label's frequency
        counts = collections.Counter(targets)
        weights = {k: 1.0 / v for k, v in counts.items()}
        return [weights[k] for k in targets]

    # NOTE(review): the logging calls below were reconstructed around the
    # surviving message strings; the ``info`` level is assumed.
    if isinstance(dataset, torch.utils.data.ConcatDataset):
        # There are two possible cases: targets/no-targets
        metadata_example = dataset.datasets[0][0][1]
        if target in metadata_example and isinstance(
            metadata_example[target], int
        ):
            # there are integer targets, let's balance with those
            logger.info(
                f"Balancing sample selection probabilities **and** "
                f"concatenated-datasets using metadata targets `{target}`"
            )
            targets = [
                k
                for ds in dataset.datasets
                for k in typing.cast(Dataset, ds).labels()
            ]
            weights = _calculate_weights(targets)
        else:
            logger.info(
                f"Balancing samples **and** concatenated-datasets "
                f"WITHOUT metadata targets (`{target}` not available)"
            )
            # Every sample of a dataset gets 1/len(dataset)
            weights = [
                k
                for ds in dataset.datasets
                for k in len(typing.cast(typing.Sized, ds))
                * [1.0 / len(typing.cast(typing.Sized, ds))]
            ]
    else:
        metadata_example = dataset[0][1]
        if target in metadata_example and isinstance(
            metadata_example[target], int
        ):
            logger.info(
                f"Balancing samples from dataset using metadata "
                f"targets `{target}`"
            )
            weights = _calculate_weights(dataset.labels())
        else:
            raise RuntimeError(
                f"Cannot balance samples without metadata targets `{target}`"
            )

    return torch.utils.data.WeightedRandomSampler(
        weights, len(weights), replacement=True
    )
class CachingDataModule(lightning.LightningDataModule):
    """A convenient data module with CSV or JSON protocol loading, mini-
    batching, parallelisation and caching, all in one.

    Instances of this class load data-split (a.k.a. protocol) definitions for a
    database, and can load the data from the disk.  An optional caching
    mechanism stores the data at associated CPU memory, which can improve data
    serving while training and evaluating models.

    This datamodule defines basic operations to handle data loading and
    mini-batch handling within this package's framework.  It can return
    :py:class:`torch.utils.data.DataLoader` objects for training, validation,
    prediction and testing conditions.  Parallelisation is handled by a simple
    input flag.

    Parameters
    ----------
    database_split
        A dictionary that contains string keys representing subset names, and
        values that are iterables over sample representations (potentially on
        disk).  These objects are passed to the ``raw_data_loader`` for
        loading the sample data (and metadata) in memory.  The objects
        represented may be of any format (e.g. list, dictionary, etc), for as
        long as the ``raw_data_loader`` can properly handle it.  As is, this
        class expects at least one entry called ``train`` to exist in the
        input dictionary.  Optional entries are ``validation``, and ``test``.
        Entries named ``monitor-...`` will be considered extra subsets that do
        not influence any early stop criteria during training, and are just
        monitored beyond the ``validation`` dataset.
    raw_data_loader
        An object instance that can load samples and labels from storage.
    cache_samples
        If set, then issue raw data loading when datasets are set up, and
        serves samples from CPU memory.  Otherwise, loads samples from disk on
        demand.  Running from CPU memory will offer increased speeds in
        exchange for CPU memory.  Sufficient CPU memory must be available
        before you set this attribute to ``True``.  It is typically useful for
        relatively small datasets.
    balance_sampler_by_class
        If set, then modifies the random sampler used during training to
        balance sample picking probability, making sample picking across
        classes equitable.
    model_transforms
        A list of transforms (torch modules) that will be applied after
        raw-data-loading, and just before data is fed into the model or
        eventual data-augmentation transformations for all data loaders
        produced by this data module.
    batch_size
        Number of samples in every **training** batch (this parameter affects
        memory requirements for the network).  If the total number of
        training samples is not a multiple of the batch-size, the last batch
        will be smaller than the first, unless ``drop_incomplete_batch`` is
        set to ``true``, in which case this batch is not used.
    batch_chunk_count
        Number of chunks in every batch (this parameter affects memory
        requirements for the network).  The number of samples loaded for every
        iteration will be ``batch_size/batch_chunk_count``.  ``batch_size``
        needs to be divisible by ``batch_chunk_count``, otherwise an error will
        be raised.  The default of 1 forces the whole batch to be processed at
        once.
    drop_incomplete_batch
        If set, then may drop the last batch in an epoch, in case it is
        incomplete.  If you set this option, you should also consider
        increasing the total number of epochs of training, as the total number
        of training steps may be reduced.
    parallel
        Use multiprocessing for data loading: if set to -1 (default), disables
        multiprocessing data loading.  Set to 0 to enable as many data loading
        instances as processing cores as available in the system.  Set to >= 1
        to enable that many multiprocessing instances for data loading.
    """

    DatasetDictionary = dict[str, Dataset]

    def __init__(
        self,
        database_split: DatabaseSplit,
        raw_data_loader: RawDataLoader,
        cache_samples: bool = False,
        balance_sampler_by_class: bool = False,
        model_transforms: list[Transform] = [],
        batch_size: int = 1,
        batch_chunk_count: int = 1,
        drop_incomplete_batch: bool = False,
        parallel: int = -1,
    ):
        super().__init__()

        self.set_chunk_size(batch_size, batch_chunk_count)

        self.database_split = database_split
        self.raw_data_loader = raw_data_loader
        self.cache_samples = cache_samples
        # Copy, so instances never share (or mutate) the default list
        self.model_transforms = list(model_transforms)
        self.drop_incomplete_batch = drop_incomplete_batch

        self.pin_memory = (
            torch.cuda.is_available() or torch.backends.mps.is_available()
        )  # should only be true if GPU available and using it

        # datasets that have been setup() for the current stage
        self._datasets: CachingDataModule.DatasetDictionary = {}

        # The ``parallel`` setter also resets ``self._datasets``, so it must
        # run before anything that may set up a dataset
        self.parallel = parallel

        # Must come last: enabling balancing sets up the ``train`` dataset,
        # which needs every attribute assigned above
        self._train_sampler = None
        self.balance_sampler_by_class = balance_sampler_by_class

    @property
    def parallel(self) -> int:
        """Whether to use multiprocessing for data loading.

        Use multiprocessing for data loading: if set to -1 (default),
        disables multiprocessing data loading.  Set to 0 to enable as
        many data loading instances as processing cores as available in
        the system.  Set to >= 1 to enable that many multiprocessing
        instances for data loading.
        """
        return self._parallel

    @parallel.setter
    def parallel(self, value: int) -> None:
        self._parallel = value
        self._dataloader_multiproc = _setup_dataloader_multiproc_parameters(
            value
        )
        # datasets that have been setup() for the current stage
        self._datasets = {}

    @property
    def balance_sampler_by_class(self):
        """Whether to balance samples across labels/datasets.

        If set, then modifies the random sampler used during training to
        balance sample picking probability.  Reading this property reports
        whether a balanced training sampler is currently installed.
        """
        return self._train_sampler is not None

    @balance_sampler_by_class.setter
    def balance_sampler_by_class(self, value: bool):
        if value:
            # The sampler weights are computed from the ``train`` dataset
            if "train" not in self._datasets:
                self._setup_dataset("train")
            self._train_sampler = _make_balanced_random_sampler(
                self._datasets["train"]
            )
        else:
            self._train_sampler = None

    def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
        """Coherently sets the batch-chunk-size after validation.

        Parameters
        ----------
        batch_size
            Number of samples in every **training** batch (this parameter
            affects memory requirements for the network).
        batch_chunk_count
            Number of chunks in every batch.  The number of samples loaded for
            every iteration will be ``batch_size/batch_chunk_count``.
            ``batch_size`` needs to be divisible by ``batch_chunk_count``,
            otherwise an error will be raised.  The default of 1 forces the
            whole batch to be processed at once.

        Raises
        ------
        RuntimeError
            If ``batch_size`` is not divisible by ``batch_chunk_count``.
        """
        # validation
        if batch_size % batch_chunk_count != 0:
            raise RuntimeError(
                f"batch_size ({batch_size}) must be divisible by "
                f"batch_chunk_count ({batch_chunk_count})."
            )

        self._batch_size = batch_size
        self._batch_chunk_count = batch_chunk_count
        self._chunk_size = self._batch_size // self._batch_chunk_count

    def _setup_dataset(self, name: str) -> None:
        """Sets-up a single dataset from the input data split.

        Parameters
        ----------
        name
            Name of the dataset to setup.
        """
        if name in self._datasets:
            logger.info(
                f"Dataset `{name}` is already setup. "
                f"Not re-instantiating it."
            )
            return

        # NOTE(review): constructor arguments below were reconstructed from
        # the dataset class signatures -- confirm against the upstream file.
        if self.cache_samples:
            logger.info(
                f"Loading dataset:`{name}` into memory (caching)."
                f" Trade-off: CPU RAM: more | Disk: less"
            )
            self._datasets[name] = _CachedDataset(
                split=self.database_split[name],
                loader=self.raw_data_loader,
                parallel=self.parallel,
                transforms=self.model_transforms,
            )
        else:
            logger.info(
                f"Loading dataset:`{name}` without caching."
                f" Trade-off: CPU RAM: less | Disk: more"
            )
            self._datasets[name] = _DelayedLoadingDataset(
                split=self.database_split[name],
                loader=self.raw_data_loader,
                transforms=self.model_transforms,
            )

    def _val_dataset_keys(self) -> list[str]:
        """Returns list of validation dataset names."""
        return ["validation"] + [
            k for k in self.database_split.keys() if k.startswith("monitor-")
        ]

    def setup(self, stage: str) -> None:
        """Sets up datasets for different tasks on the pipeline.

        This method should setup (load, pre-process, etc) all datasets required
        for a particular ``stage`` (fit, validate, test, predict), and keep
        them ready to be used on one of the `_dataloader()` functions that are
        pertinent for such stage.

        If you have set ``cache_samples``, samples are loaded at this stage and
        cached in memory.

        Parameters
        ----------
        stage
            Name of the stage to which the setup is applicable.  Can be one of
            ``fit``, ``validate``, ``test`` or ``predict``.  Each stage
            typically uses the following data loaders:

            * ``fit``: uses both train and validation datasets
            * ``validate``: uses only the validation dataset
            * ``test``: uses only the test dataset
            * ``predict``: uses only the test dataset
        """
        # NOTE(review): branch bodies were reconstructed from the stage
        # description above -- confirm against the upstream file.
        if stage == "fit":
            for k in ["train"] + self._val_dataset_keys():
                self._setup_dataset(k)
        elif stage == "validate":
            for k in self._val_dataset_keys():
                self._setup_dataset(k)
        elif stage == "test":
            self._setup_dataset("test")
        elif stage == "predict":
            self._setup_dataset("test")

    def teardown(self, stage: str) -> None:
        """Unset-up datasets for different tasks on the pipeline.

        This method unsets (unload, remove from memory, etc) all datasets
        required for a particular ``stage`` (fit, validate, test, predict).

        If you have set ``cache_samples``, this may effectively release all
        the associated memory.

        Parameters
        ----------
        stage
            Name of the stage to which the teardown is applicable.  Can be one
            of ``fit``, ``validate``, ``test`` or ``predict``.  The same
            reset is applied whichever stage is given.
        """
        self._datasets = {}

    def train_dataloader(self) -> DataLoader:
        """Returns the train data loader."""
        # NOTE(review): only ``shuffle`` survived in the source; the other
        # arguments mirror the options used by ``val_dataloader`` -- confirm.
        return torch.utils.data.DataLoader(
            self._datasets["train"],
            # pytorch does not accept shuffling together with a sampler
            shuffle=(self._train_sampler is None),
            batch_size=self._chunk_size,
            drop_last=self.drop_incomplete_batch,
            pin_memory=self.pin_memory,
            sampler=self._train_sampler,
            **self._dataloader_multiproc,
        )

    def unshuffled_train_dataloader(self) -> DataLoader:
        """Returns the train data loader without shuffling."""
        # NOTE(review): arguments reconstructed -- confirm against upstream.
        return torch.utils.data.DataLoader(
            self._datasets["train"],
            shuffle=False,
            batch_size=self._chunk_size,
            drop_last=False,
            **self._dataloader_multiproc,
        )

    def val_dataloader(self) -> dict[str, DataLoader]:
        """Returns the validation data loader(s)"""
        validation_loader_opts = {
            "batch_size": self._chunk_size,
            "shuffle": False,
            "drop_last": self.drop_incomplete_batch,
            "pin_memory": self.pin_memory,
        }
        # NOTE(review): merging the multiprocessing options is assumed, for
        # consistency with the ``parallel`` setting -- confirm.
        validation_loader_opts.update(self._dataloader_multiproc)

        return {
            k: torch.utils.data.DataLoader(
                self._datasets[k], **validation_loader_opts
            )
            for k in self._val_dataset_keys()
        }

    def test_dataloader(self) -> dict[str, DataLoader]:
        """Returns the test data loader(s)"""
        # NOTE(review): arguments reconstructed -- confirm against upstream.
        return dict(
            test=torch.utils.data.DataLoader(
                self._datasets["test"],
                batch_size=self._chunk_size,
                shuffle=False,
                drop_last=self.drop_incomplete_batch,
                pin_memory=self.pin_memory,
                **self._dataloader_multiproc,
            )
        )

    def predict_dataloader(self) -> dict[str, DataLoader]:
        """Returns the prediction data loader(s)"""
        return self.test_dataloader()