Skip to content
Snippets Groups Projects
datamodule.py 26.29 KiB
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import collections
import logging
import multiprocessing
import sys
import typing

import lightning
import torch
import torch.backends
import torch.utils.data
import torchvision.transforms
import tqdm

from .typing import (
    DatabaseSplit,
    DataLoader,
    Dataset,
    RawDataLoader,
    Sample,
    Transform,
)

logger = logging.getLogger(__name__)


def _setup_dataloader_multiproc_parameters(
    parallel: int,
) -> dict[str, typing.Any]:
    """Builds the multiprocessing-related keyword arguments for pytorch data
    loaders.

    The command-line ``parallel`` setting is translated into the
    ``num_workers`` keyword expected by
    :py:class:`torch.utils.data.DataLoader`.  On macOS (``sys.platform ==
    "darwin"``), whenever worker processes are used, a ``spawn``
    ``multiprocessing_context`` is requested as well, instead of the default.

    .. list-table:: Mapping from ``parallel`` to DataLoader keyword arguments
       :widths: 15 15 70
       :header-rows: 1

       * - CLI ``parallel``
         - ``num_workers``
         - Effect
       * - ``<0``
         - 0
         - No worker processes: mini-batches are loaded within the calling
           process
       * - ``0``
         - :py:func:`multiprocessing.cpu_count`
         - One worker process per CPU available on the current machine
       * - ``>=1``
         - ``parallel``
         - Exactly ``parallel`` worker processes


    Parameters
    ----------

    parallel
        The command-line parallelisation setting.


    Returns
    -------

    kwargs
        A dictionary with a ``num_workers`` entry and, on macOS with at least
        one worker, a ``multiprocessing_context`` entry.
    """
    if parallel < 0:
        workers = 0
    elif parallel == 0:
        workers = multiprocessing.cpu_count()
    else:
        workers = parallel

    kwargs: dict[str, typing.Any] = {"num_workers": workers}

    # the start-method override only matters when workers are started at all
    if workers > 0 and sys.platform == "darwin":
        kwargs["multiprocessing_context"] = multiprocessing.get_context(
            "spawn"
        )

    return kwargs


class _DelayedLoadingDataset(Dataset):
    """A list that loads its samples on demand.

    This list mimics a pytorch Dataset, except raw data loading is done
    on-the-fly, as the samples are requested through the bracket operator.


    Parameters
    ----------

    split
        A sequence (indexable, with a length) containing the raw dataset
        samples loaded from the database splits.  Each entry is handed to the
        ``loader`` when the matching sample is requested.

    loader
        An object instance that can load samples and labels from storage.

    transforms
        A set of transforms that should be applied on-the-fly for this dataset,
        to fit the output of the raw-data-loader to the model of interest.
        Defaults to no transforms.
    """

    def __init__(
        self,
        split: typing.Sequence[typing.Any],
        loader: RawDataLoader,
        transforms: typing.Optional[typing.Sequence[Transform]] = None,
    ):
        self.split = split
        self.loader = loader
        # Compose stores the sequence it receives, so a shared mutable default
        # would be aliased by every instance built without explicit
        # transforms; build a fresh list per instance instead.
        self.transform = torchvision.transforms.Compose(
            [] if transforms is None else transforms
        )

    def labels(self) -> list[int]:
        """Returns the integer labels for all samples in the dataset.

        Labels are obtained by calling ``loader.label()`` on every entry of
        the split.
        """
        return [self.loader.label(k) for k in self.split]

    def __getitem__(self, key: int) -> Sample:
        """Loads sample ``key`` from storage and applies the transforms.

        Only the tensor part of the loaded sample is transformed; the
        metadata is returned as provided by the loader.
        """
        tensor, metadata = self.loader.sample(self.split[key])
        return self.transform(tensor), metadata

    def __len__(self):
        """Returns the number of entries in the split."""
        return len(self.split)

    def __iter__(self):
        """Yields every (transformed) sample, in split order."""
        for x in range(len(self)):
            yield self[x]


class _CachedDataset(Dataset):
    """Basically, a list of preloaded samples.

    This dataset will load all samples from the split during construction
    instead of delaying that to the indexing.  The cache holds the output of
    the raw-data-loader; the ``transforms`` given upon construction are
    applied each time a sample is retrieved through the bracket operator.


    Parameters
    ----------

    split
        A sequence (with a length) containing the raw dataset samples loaded
        from the database splits.

    loader
        An object instance that can load samples and labels from storage.

    parallel
        Use multiprocessing for data loading: if set to -1 (default), disables
        multiprocessing data loading.  Set to 0 to enable as many data loading
        instances as processing cores as available in the system.  Set to >= 1
        to enable that many multiprocessing instances for data loading.

    transforms
        A set of transforms that should be applied to the cached samples, upon
        retrieval, for this dataset, to fit the output of the raw-data-loader
        to the model of interest.  Defaults to no transforms.
    """

    def __init__(
        self,
        split: typing.Sequence[typing.Any],
        loader: RawDataLoader,
        parallel: int = -1,
        transforms: typing.Optional[typing.Sequence[Transform]] = None,
    ):
        # Compose stores the sequence it receives, so a shared mutable default
        # would be aliased by every instance built without explicit
        # transforms; build a fresh list per instance instead.
        self.transform = torchvision.transforms.Compose(
            [] if transforms is None else transforms
        )

        if parallel < 0:
            self.data = [
                loader.sample(k) for k in tqdm.tqdm(split, unit="sample")
            ]
        else:
            instances = parallel or multiprocessing.cpu_count()
            logger.info(f"Caching dataset using {instances} processes...")
            # NOTE(review): ``loader.sample`` is sent to the pool workers, so
            # the loader presumably needs to be picklable -- confirm.
            with multiprocessing.Pool(instances) as p:
                self.data = list(
                    tqdm.tqdm(
                        p.imap(loader.sample, split),
                        total=len(split),
                        unit="sample",
                    )
                )

    def labels(self) -> list[int]:
        """Returns the integer labels for all samples in the dataset.

        Labels are read from the ``"label"`` entry of each cached sample's
        metadata; the loader is not consulted after construction.
        """
        return [k[1]["label"] for k in self.data]

    def __getitem__(self, key: int) -> Sample:
        """Returns cached sample ``key`` with the transforms applied.

        Only the tensor part is transformed; the metadata is returned as
        cached.
        """
        tensor, metadata = self.data[key]
        return self.transform(tensor), metadata

    def __len__(self):
        """Returns the number of cached samples."""
        return len(self.data)

    def __iter__(self):
        """Yields every (transformed) cached sample, in split order."""
        for x in range(len(self)):
            yield self[x]


def _make_balanced_random_sampler(
    dataset: Dataset,
    target: str = "label",
) -> torch.utils.data.WeightedRandomSampler:
    """Generates a pytorch sampler that samples according to class
    probabilities.

    This function takes as input a torch Dataset, and computes the weights to
    balance each class in the dataset, and the datasets themselves if one
    passes a :py:class:`torch.utils.data.ConcatDataset`.

    In this implementation, we balance **both** class and dataset-origin
    probabilities, what you expect for a truly *equitable* random sampler.

    Take this example for illustration:

    * Dataset 1: N = 10 samples, 9 samples with target=0, 1 sample with target=1
    * Dataset 2: N = 6 samples, 3 samples with target=0, 3 samples with target=1

    So:

    =======  ======  =======  ======  =================
    Dataset  Target  Samples  Weight  Normalised weight
    =======  ======  =======  ======  =================
    1        0       9        1/9     1/36
    1        1       1        1/1     1/4
    2        0       3        1/3     1/12
    2        1       3        1/3     1/12
    =======  ======  =======  ======  =================

    Legend:

    * Weight: the per-sample weight computed by this method, i.e. the inverse
      of the number of samples sharing the same dataset of origin and target.
    * Normalised weight: the per-sample probability used by the random
      sampler, obtained by dividing each weight by the sum of the weights of
      all samples in the concatenated dataset, such that the probabilities of
      all samples add up to 1.

    The properties of this algorithm are as follows:

    1. The probability of picking a sample from any target is the same (0.5 in
       this case).  To verify this, notice that the probability of picking a
       sample with ``target=1`` is ``1/4 * 1 + 1/12 * 3 = 0.5``.
    2. The probability of picking a sample with ``target=0`` from Dataset 2 is
       3 times higher than those from Dataset 1.  As there are 3 times less
       samples in Dataset 2 with ``target=0``, this makes choosing samples from
       Dataset 1 proportionally less likely.
    3. The probability of picking a sample with ``target=1`` from Dataset 2 is
       3 times lower than those from Dataset 1.  As there are 3 times less
       samples in Dataset 1 with ``target=1``, this makes choosing samples from
       Dataset 2 proportionally less likely.

    This function checks that the metadata of the first sample contains an
    integer entry named after ``target``.  The integer labels themselves are
    obtained by calling ``labels()`` on each dataset.

    We then instantiate a pytorch sampler using the inverse probabilities (the
    more samples of a class, the less likely it becomes to be sampled).


    Parameters
    ----------

    dataset
        An instance of torch Dataset.
        :py:class:`torch.utils.data.ConcatDataset` are supported.

    target
        The name of a metadata key pointing to an integer property that allows
        balancing the dataset.  NOTE(review): this key is only used to decide
        whether targets are available; the values used for balancing come from
        each dataset's ``labels()`` method.


    Returns
    -------

    sampler
        A sampler, to be used in a dataloader equipped with the same dataset
        used to calculate the relative sample weights.


    Raises
    ------

    RuntimeError
        If requested to balance a dataset (single, not-concatenated) without an
        existing target.
    """

    def _calculate_weights(keys: list[typing.Hashable]) -> list[float]:
        # every sample weighs the inverse of how many samples share its key
        counts = collections.Counter(keys)
        return [1.0 / counts[k] for k in keys]

    if isinstance(dataset, torch.utils.data.ConcatDataset):
        # There are two possible cases: targets/no-targets
        metadata_example = dataset.datasets[0][0][1]
        if target in metadata_example and isinstance(
            metadata_example[target], int
        ):
            # there are integer targets, let's balance with those
            logger.info(
                f"Balancing sample selection probabilities **and** "
                f"concatenated-datasets using metadata targets `{target}`"
            )
            # Key each sample by (dataset index, target): counting targets
            # per dataset balances classes *and* dataset origin, as in the
            # table above.  Counting targets over the whole concatenation
            # would ignore which dataset each sample came from.
            keys = [
                (i, k)
                for i, ds in enumerate(dataset.datasets)
                for k in typing.cast(Dataset, ds).labels()
            ]
            weights = _calculate_weights(keys)
        else:
            logger.warning(
                f"Balancing samples **and** concatenated-datasets "
                f"WITHOUT metadata targets (`{target}` not available)"
            )
            # without targets, every dataset receives the same total weight
            # (1.0), spread evenly over its samples
            weights = []
            for ds in dataset.datasets:
                n = len(typing.cast(typing.Sized, ds))
                if n == 0:
                    # empty datasets contribute no samples (and no weights)
                    continue
                weights.extend([1.0 / n] * n)

    else:
        metadata_example = dataset[0][1]
        if target in metadata_example and isinstance(
            metadata_example[target], int
        ):
            logger.info(
                f"Balancing samples from dataset using metadata "
                f"targets `{target}`"
            )
            weights = _calculate_weights(dataset.labels())
        else:
            raise RuntimeError(
                f"Cannot balance samples without metadata targets `{target}`"
            )

    return torch.utils.data.WeightedRandomSampler(
        weights, len(weights), replacement=True
    )


class CachingDataModule(lightning.LightningDataModule):
    """A convenient data module with CSV or JSON protocol loading, mini-
    batching, parallelisation and caching, all in one.

    Instances of this class load data-split (a.k.a. protocol) definitions for a
    database, and can load the data from the disk.  An optional caching
    mechanism stores the data in CPU memory, which can improve data serving
    while training and evaluating models.

    This datamodule defines basic operations to handle data loading and
    mini-batch handling within this package's framework.  It can return
    :py:class:`torch.utils.data.DataLoader` objects for training, validation,
    prediction and testing conditions.  Parallelisation is handled by a simple
    input flag.

    Datasets are instantiated by :py:meth:`setup`, which is parameterised by a
    single string enumeration containing: ``fit``, ``validate``, ``test``, or
    ``predict``.


    Parameters
    ----------

    database_split
        A dictionary that contains string keys representing subset names, and
        values that are sequences over sample representations (potentially on
        disk).  These objects are passed to the ``raw_data_loader`` for loading
        the sample data (and metadata) in memory.  The objects represented may
        be of any format (e.g. list, dictionary, etc), for as long as the
        ``raw_data_loader`` can properly handle it.  To check the split and the
        loader function works correctly, you may use
        :py:func:`..dataset.check_database_split_loading`.  As is, this class
        expects at least one entry called ``train`` to exist in the input
        dictionary.  An entry called ``validation`` is required by the ``fit``
        and ``validate`` stages, and one called ``test`` by the ``test`` and
        ``predict`` stages.  Entries named ``monitor-...`` will be considered
        extra subsets that do not influence any early stop criteria during
        training, and are just monitored beyond the ``validation`` dataset.

    raw_data_loader
        An object instance that can load samples and labels from storage.

    cache_samples
        If set, then raw data loading is issued when datasets are set up (see
        :py:meth:`setup`), and samples are served from CPU memory.  Otherwise,
        loads samples from disk on demand.  Running from CPU memory will offer
        increased speeds in exchange for CPU memory.  Sufficient CPU memory
        must be available before you set this attribute to ``True``.  It is
        typically useful for relatively small datasets.

    balance_sampler_by_class
        If set, then modifies the random sampler used during training to
        balance sample picking probability, making sample picking across
        classes **and** datasets equitable.  Enabling it sets up the ``train``
        dataset right away, as the sampler weights are computed from it.

    model_transforms
        A list of transforms (torch modules) that will be applied after
        raw-data-loading, and just before data is fed into the model or
        eventual data-augmentation transformations for all data loaders
        produced by this data module.  This part of the pipeline receives data
        as output by the raw-data-loader, or model-related transforms (e.g.
        resize adaptions), if any is specified.  Defaults to no transforms.

    batch_size
        Number of samples in every **training** batch (this parameter affects
        memory requirements for the network).  If the number of samples in the
        batch is larger than the total number of samples available for
        training, this value is truncated.  If this number is smaller, then
        batches of the specified size are created and fed to the network until
        there are no more new samples to feed (epoch is finished).  If the
        total number of training samples is not a multiple of the batch-size,
        the last batch will be smaller than the first, unless
        ``drop_incomplete_batch`` is set to ``true``, in which case this batch
        is not used.

    batch_chunk_count
        Number of chunks in every batch (this parameter affects memory
        requirements for the network).  The number of samples loaded for every
        iteration will be ``batch_size/batch_chunk_count``.  ``batch_size``
        needs to be divisible by ``batch_chunk_count``, otherwise an error will
        be raised.  This parameter is used to reduce number of samples loaded
        in each iteration, in order to reduce the memory usage in exchange for
        processing time (more iterations).  This is especially interesting
        when one is running with GPUs with limited RAM.  The default of 1
        forces the whole batch to be processed at once.  Otherwise the batch is
        broken into batch-chunk-count pieces, and gradients are accumulated to
        complete each batch.

    drop_incomplete_batch
        If set, then may drop the last batch in an epoch, in case it is
        incomplete.  If you set this option, you should also consider
        increasing the total number of epochs of training, as the total number
        of training steps may be reduced.

    parallel
        Use multiprocessing for data loading: if set to -1 (default), disables
        multiprocessing data loading.  Set to 0 to enable as many data loading
        instances as processing cores as available in the system.  Set to >= 1
        to enable that many multiprocessing instances for data loading.
    """

    # maps subset names (e.g. ``train``, ``validation``) to set-up datasets
    DatasetDictionary = dict[str, Dataset]

    def __init__(
        self,
        database_split: DatabaseSplit,
        raw_data_loader: RawDataLoader,
        cache_samples: bool = False,
        balance_sampler_by_class: bool = False,
        model_transforms: typing.Optional[list[Transform]] = None,
        batch_size: int = 1,
        batch_chunk_count: int = 1,
        drop_incomplete_batch: bool = False,
        parallel: int = -1,
    ):
        super().__init__()

        self.set_chunk_size(batch_size, batch_chunk_count)

        self.database_split = database_split
        self.raw_data_loader = raw_data_loader
        self.cache_samples = cache_samples
        # a fresh list per instance, so a default list is never shared
        self.model_transforms = (
            [] if model_transforms is None else model_transforms
        )

        self.drop_incomplete_batch = drop_incomplete_batch

        # the ``parallel`` setter also computes the DataLoader multiprocessing
        # options and resets any set-up datasets
        self.parallel = parallel

        # pinning is requested only when a CUDA or MPS backend is available
        self.pin_memory = (
            torch.cuda.is_available() or torch.backends.mps.is_available()
        )

        # datasets that have been setup() for the current stage
        self._datasets: CachingDataModule.DatasetDictionary = {}

        # Must come last: enabling balancing sets up the ``train`` dataset,
        # which needs ``_datasets``, ``model_transforms`` and ``parallel`` to
        # exist already.
        self._train_sampler: typing.Optional[
            torch.utils.data.WeightedRandomSampler
        ] = None
        self.balance_sampler_by_class = balance_sampler_by_class

    @property
    def parallel(self) -> int:
        """Whether to use multiprocessing for data loading.

        Use multiprocessing for data loading: if set to -1 (default),
        disables multiprocessing data loading.  Set to 0 to enable as
        many data loading instances as processing cores as available in
        the system.  Set to >= 1 to enable that many multiprocessing
        instances for data loading.
        """
        return self._parallel

    @parallel.setter
    def parallel(self, value: int) -> None:
        """Stores ``value``, recomputes the DataLoader multiprocessing
        options and discards all datasets set up so far."""
        self._parallel = value
        self._dataloader_multiproc = _setup_dataloader_multiproc_parameters(
            value
        )
        # datasets that have been setup() for the current stage
        self._datasets = {}

    @property
    def balance_sampler_by_class(self):
        """Whether to balance samples across labels/datasets.

        If set, then modifies the random sampler used during training
        to balance sample picking probability, making sample picking
        across classes **and** datasets equitable.
        """
        return self._train_sampler is not None

    @balance_sampler_by_class.setter
    def balance_sampler_by_class(self, value: bool):
        """Builds (``True``) or drops (``False``) the balanced train sampler.

        Enabling it sets up the ``train`` dataset immediately if it is not
        set up yet, since the sampler weights are computed from it.
        """
        if value:
            if "train" not in self._datasets:
                self._setup_dataset("train")
            self._train_sampler = _make_balanced_random_sampler(
                self._datasets["train"]
            )
        else:
            self._train_sampler = None

    def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
        """Coherently sets the batch-chunk-size after validation.

        Parameters
        ----------

        batch_size
            Number of samples in every **training** batch (this parameter
            affects memory requirements for the network).  Must be a positive
            integer.  If the number of samples in the batch is larger than
            the total number of samples available for training, this value is
            truncated.  If this number is smaller, then batches of the
            specified size are created and fed to the network until there are
            no more new samples to feed (epoch is finished).  If the total
            number of training samples is not a multiple of the batch-size,
            the last batch will be smaller than the first, unless
            ``drop_incomplete_batch`` is set to ``true``, in which case this
            batch is not used.

        batch_chunk_count
            Number of chunks in every batch (this parameter affects memory
            requirements for the network).  Must be a positive integer.  The
            number of samples loaded for every iteration will be
            ``batch_size/batch_chunk_count``.  ``batch_size`` needs to be
            divisible by ``batch_chunk_count``, otherwise an error will be
            raised.  This parameter is used to reduce number of samples loaded
            in each iteration, in order to reduce the memory usage in exchange
            for processing time (more iterations).  This is especially
            interesting when one is running with GPUs with limited RAM.  The
            default of 1 forces the whole batch to be processed at once.
            Otherwise the batch is broken into batch-chunk-count pieces, and
            gradients are accumulated to complete each batch.


        Raises
        ------

        RuntimeError
            If either value is not positive, or if ``batch_size`` is not
            divisible by ``batch_chunk_count``.
        """

        # validation (also avoids a ZeroDivisionError on a zero chunk count)
        if batch_size < 1 or batch_chunk_count < 1:
            raise RuntimeError(
                f"batch_size ({batch_size}) and batch_chunk_count "
                f"({batch_chunk_count}) must both be positive integers."
            )
        if batch_size % batch_chunk_count != 0:
            raise RuntimeError(
                f"batch_size ({batch_size}) must be divisible by "
                f"batch_chunk_count ({batch_chunk_count})."
            )

        self._batch_size = batch_size
        self._batch_chunk_count = batch_chunk_count
        self._chunk_size = self._batch_size // self._batch_chunk_count

    def _setup_dataset(self, name: str) -> None:
        """Sets-up a single dataset from the input data split.

        Does nothing if ``name`` is already set up.  Otherwise builds a
        cached (``cache_samples`` set) or on-demand dataset from
        ``database_split[name]``.

        Parameters
        ----------

        name
            Name of the dataset to setup.
        """

        if name in self._datasets:
            logger.info(
                f"Dataset `{name}` is already setup. "
                f"Not re-instantiating it."
            )
            return
        if self.cache_samples:
            logger.info(
                f"Loading dataset:`{name}` into memory (caching)."
                f" Trade-off: CPU RAM: more | Disk: less"
            )
            self._datasets[name] = _CachedDataset(
                self.database_split[name],
                self.raw_data_loader,
                self.parallel,
                self.model_transforms,
            )
        else:
            logger.info(
                f"Loading dataset:`{name}` without caching."
                f" Trade-off: CPU RAM: less | Disk: more"
            )
            self._datasets[name] = _DelayedLoadingDataset(
                self.database_split[name],
                self.raw_data_loader,
                self.model_transforms,
            )

    def _val_dataset_keys(self) -> list[str]:
        """Returns list of validation dataset names.

        ``validation`` always comes first, followed by every
        ``monitor-...`` subset in the database split.
        """
        return ["validation"] + [
            k for k in self.database_split.keys() if k.startswith("monitor-")
        ]

    def setup(self, stage: str) -> None:
        """Sets up datasets for different tasks on the pipeline.

        This method sets up (loads, pre-processes, etc) all datasets required
        for a particular ``stage`` (fit, validate, test, predict), and keeps
        them ready to be used on one of the ``_dataloader()`` functions that
        are pertinent for such stage.  Unknown stages are ignored.

        If you have set ``cache_samples``, samples are loaded at this stage and
        cached in memory.


        Parameters
        ----------

        stage
            Name of the stage to which the setup is applicable.  Can be one of
            ``fit``, ``validate``, ``test`` or ``predict``.  Each stage
            sets up the following datasets:

            * ``fit``: train, validation and ``monitor-...`` datasets
            * ``validate``: validation and ``monitor-...`` datasets
            * ``test``: only the test dataset
            * ``predict``: only the test dataset
        """

        if stage == "fit":
            for k in ["train"] + self._val_dataset_keys():
                self._setup_dataset(k)

        elif stage == "validate":
            for k in self._val_dataset_keys():
                self._setup_dataset(k)

        elif stage == "test":
            self._setup_dataset("test")

        elif stage == "predict":
            self._setup_dataset("test")

    def teardown(self, stage: str) -> None:
        """Unset-up datasets for different tasks on the pipeline.

        Regardless of ``stage``, this method drops references to **all**
        datasets set up so far.  If you have set ``cache_samples``, this may
        effectively release the associated memory.  A balanced train sampler,
        if any, is kept.


        Parameters
        ----------

        stage
            Name of the stage to which the teardown is applicable.  Can be one
            of ``fit``, ``validate``, ``test`` or ``predict``.  It is not used
            by this implementation.
        """

        self._datasets = {}

    def train_dataloader(self) -> DataLoader:
        """Returns the train data loader.

        Shuffles the data unless a balanced sampler is in use, in which case
        the sampler drives sample selection.
        """

        return torch.utils.data.DataLoader(
            self._datasets["train"],
            shuffle=(self._train_sampler is None),
            batch_size=self._chunk_size,
            drop_last=self.drop_incomplete_batch,
            pin_memory=self.pin_memory,
            sampler=self._train_sampler,
            **self._dataloader_multiproc,
        )

    def unshuffled_train_dataloader(self) -> DataLoader:
        """Returns the train data loader without shuffling.

        Every training sample is served, in split order (no batch is dropped
        and no balanced sampler is used).
        """

        return torch.utils.data.DataLoader(
            self._datasets["train"],
            shuffle=False,
            batch_size=self._chunk_size,
            drop_last=False,
            pin_memory=self.pin_memory,
            **self._dataloader_multiproc,
        )

    def val_dataloader(self) -> dict[str, DataLoader]:
        """Returns the validation data loader(s).

        One loader per name in ``_val_dataset_keys()``, keyed by that name.
        """

        # NOTE(review): ``drop_incomplete_batch`` also applies here, so some
        # validation samples may be skipped -- confirm this is intended.
        validation_loader_opts = {
            "batch_size": self._chunk_size,
            "shuffle": False,
            "drop_last": self.drop_incomplete_batch,
            "pin_memory": self.pin_memory,
        }
        validation_loader_opts.update(self._dataloader_multiproc)

        return {
            k: torch.utils.data.DataLoader(
                self._datasets[k], **validation_loader_opts
            )
            for k in self._val_dataset_keys()
        }

    def test_dataloader(self) -> dict[str, DataLoader]:
        """Returns the test data loader(s), keyed by ``test``."""

        # NOTE(review): ``drop_incomplete_batch`` also applies here, so some
        # test samples may be skipped -- confirm this is intended.
        return dict(
            test=torch.utils.data.DataLoader(
                self._datasets["test"],
                batch_size=self._chunk_size,
                shuffle=False,
                drop_last=self.drop_incomplete_batch,
                pin_memory=self.pin_memory,
                **self._dataloader_multiproc,
            )
        )

    def predict_dataloader(self) -> dict[str, DataLoader]:
        """Returns the prediction data loader(s), identical to the test
        data loader(s)."""

        return self.test_dataloader()