Commit eea2f306 authored by André Anjos

[data.datamodule] Implements ConcatDataModule (closes #16); Streamline types (see #24)

parent 6f2383c8
1 merge request: !6 Making use of LightningDataModule and simplification of data loading
@@ -4,6 +4,7 @@
 import collections
 import functools
+import itertools
 import logging
 import multiprocessing
 import sys
@@ -17,12 +18,14 @@ import torchvision.transforms
 import tqdm

 from .typing import (
+    ConcatDatabaseSplit,
     DatabaseSplit,
     DataLoader,
     Dataset,
     RawDataLoader,
     Sample,
     Transform,
+    TransformSequence,
 )

 logger = logging.getLogger(__name__)
@@ -72,9 +75,9 @@ class _DelayedLoadingDataset(Dataset):
     Parameters
     ----------

-    split
-        An iterable containing the raw dataset samples loaded from the database
-        splits.
+    raw_dataset
+        An iterable containing the raw dataset samples representing one of the
+        database split datasets.

     loader
         An object instance that can load samples and labels from storage.
@@ -86,11 +89,11 @@ class _DelayedLoadingDataset(Dataset):
     def __init__(
         self,
-        split: typing.Sequence[typing.Any],
+        raw_dataset: typing.Sequence[typing.Any],
         loader: RawDataLoader,
-        transforms: typing.Sequence[Transform] = [],
+        transforms: TransformSequence = [],
     ):
-        self.split = split
+        self.raw_dataset = raw_dataset
         self.loader = loader
         self.transform = torchvision.transforms.Compose(transforms)
@@ -105,14 +108,14 @@ class _DelayedLoadingDataset(Dataset):
     def labels(self) -> list[int]:
         """Returns the integer labels for all samples in the dataset."""
-        return [self.loader.label(k) for k in self.split]
+        return [self.loader.label(k) for k in self.raw_dataset]

     def __getitem__(self, key: int) -> Sample:
-        tensor, metadata = self.loader.sample(self.split[key])
+        tensor, metadata = self.loader.sample(self.raw_dataset[key])
         return self.transform(tensor), metadata

     def __len__(self):
-        return len(self.split)
+        return len(self.raw_dataset)

     def __iter__(self):
         for x in range(len(self)):
@@ -131,7 +134,7 @@ def _apply_loader_and_transforms(
     ----------

     info
-        The sample information, as loaded from its split dictionary
+        The sample information, as loaded from its raw dataset dictionary

     load
         The raw-data loader function to use for loading the sample
@@ -155,7 +158,7 @@ def _apply_loader_and_transforms(
 class _CachedDataset(Dataset):
     """Basically, a list of preloaded samples.

-    This dataset will load all samples from the split during construction
+    This dataset will load all samples from the raw dataset during construction
     instead of delaying that to the indexing.  Beyond raw-data-loading,
     ``transforms`` given upon construction contribute to the cached samples.
@@ -163,9 +166,9 @@ class _CachedDataset(Dataset):
     Parameters
     ----------

-    split
-        An iterable containing the raw dataset samples loaded from the database
-        splits.
+    raw_dataset
+        An iterable containing the raw dataset samples representing one of the
+        database split datasets.

     loader
         An object instance that can load samples and labels from storage.
@@ -184,10 +187,10 @@ class _CachedDataset(Dataset):
     def __init__(
         self,
-        split: typing.Sequence[typing.Any],
+        raw_dataset: typing.Sequence[typing.Any],
         loader: RawDataLoader,
         parallel: int = -1,
-        transforms: typing.Sequence[Transform] = [],
+        transforms: TransformSequence = [],
     ):
         self.loader = functools.partial(
             _apply_loader_and_transforms,
@@ -197,14 +200,16 @@ class _CachedDataset(Dataset):
         if parallel < 0:
             self.data = [
-                self.loader(k) for k in tqdm.tqdm(split, unit="sample")
+                self.loader(k) for k in tqdm.tqdm(raw_dataset, unit="sample")
             ]
         else:
             instances = parallel or multiprocessing.cpu_count()
             logger.info(f"Caching dataset using {instances} processes...")
             with multiprocessing.Pool(instances) as p:
                 self.data = list(
-                    tqdm.tqdm(p.imap(self.loader, split), total=len(split))
+                    tqdm.tqdm(
+                        p.imap(self.loader, raw_dataset), total=len(raw_dataset)
+                    )
                 )

         # Estimates memory occupancy
@@ -229,8 +234,42 @@ class _CachedDataset(Dataset):
         return len(self.data)

     def __iter__(self):
-        for x in range(len(self)):
-            yield self[x]
+        yield from self.data
+
+
+class _ConcatDataset(Dataset):
+    """A dataset that represents a concatenation of other cached or delayed
+    datasets.
+
+    Parameters
+    ----------
+
+    datasets
+        An iterable over pre-instantiated datasets.
+    """
+
+    def __init__(self, datasets: typing.Sequence[Dataset]):
+        self._datasets = datasets
+        self._indices = [
+            (i, j)  # dataset relative position, sample relative position
+            for i in range(len(datasets))
+            for j in range(len(datasets[i]))
+        ]
+
+    def labels(self) -> list[int]:
+        """Returns the integer labels for all samples in the dataset."""
+        return list(itertools.chain(*[k.labels() for k in self._datasets]))
+
+    def __getitem__(self, key: int) -> Sample:
+        i, j = self._indices[key]
+        return self._datasets[i][j]
+
+    def __len__(self):
+        return sum([len(k) for k in self._datasets])
+
+    def __iter__(self):
+        for dataset in self._datasets:
+            yield from dataset


 def _make_balanced_random_sampler(
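As a reading aid (not part of the commit), here is a small, self-contained sketch of the index flattening that the new ``_ConcatDataset`` performs; the toy dataset contents are made up:

# Illustration only: how _ConcatDataset's flat index resolves to a
# (dataset position, sample position) pair.  Toy data, not from the commit.
datasets = [["a0", "a1"], ["b0", "b1", "b2"]]

indices = [
    (i, j)  # dataset relative position, sample relative position
    for i in range(len(datasets))
    for j in range(len(datasets[i]))
]
assert indices == [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]

# Global key 3 therefore maps to the second sample of the second dataset,
# which is how __getitem__ above dispatches lookups:
i, j = indices[3]
assert datasets[i][j] == "b1"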
@@ -375,14 +414,15 @@
     )


-class CachingDataModule(lightning.LightningDataModule):
-    """A convenient data module with CSV or JSON protocol loading,
-    mini-batching, parallelisation and caching, all in one.
+class ConcatDataModule(lightning.LightningDataModule):
+    """A convenient data module with dictionary split loading, mini-batching,
+    parallelisation and caching, all in one.

-    Instances of this class load data-split (a.k.a. protocol) definitions for a
-    database, and can load the data from the disk.  An optional caching
-    mechanism stores the data at associated CPU memory, which can improve data
-    serving while training and evaluating models.
+    Instances of this class can load and concatenate an arbitrary number of
+    data-split (a.k.a. protocol) definitions for (possibly disjoint) databases,
+    and can manage raw data-loading from disk.  An optional caching mechanism
+    stores the data at associated CPU memory, which can improve data serving
+    while training and evaluating models.

     This datamodule defines basic operations to handle data loading and
     mini-batch handling within this package's framework.  It can return
@@ -390,31 +430,32 @@ class CachingDataModule(lightning.LightningDataModule):
     prediction and testing conditions.  Parallelisation is handled by a simple
     input flag.

-    Users must implement the basic :py:meth:`setup` function, which is
-    parameterised by a single string enumeration containing: ``fit``,
-    ``validate``, ``test``, or ``predict``.
-
     Parameters
     ----------

-    database_split
-        A dictionary that contains string keys representing subset names, and
-        values that are iterables over sample representations (potentially on
-        disk).  These objects are passed to the ``sample_loader`` for loading
-        the sample data (and metadata) in memory.  The objects represented may
-        be of any format (e.g. list, dictionary, etc), for as long as the
-        ``sample_loader`` can properly handle it.  To check the split and the
-        loader function works correctly, you may use
-        :py:func:`..dataset.check_database_split_loading`.  As is, this class
-        expects at least one entry called ``train`` to exist in the input
-        dictionary.  Optional entries are ``validation``, and ``test``.  Entries
-        named ``monitor-...`` will be considered extra subsets that do not
-        influence any early stop criteria during training, and are just
-        monitored beyond the ``validation`` dataset.
-
-    loader
-        An object instance that can load samples and labels from storage.
+    splits
+        A dictionary that contains string keys representing dataset names, and
+        values that are iterables over 2-tuples, each containing an iterable
+        over arbitrary, user-configurable sample representations (potentially
+        on disk or permanent storage), and a :py:class:`RawDataLoader` (or
+        "sample") loader object, which concretely implements a mechanism to
+        load such samples in memory, from permanent storage.
+
+        Sample representations on permanent storage may be of any iterable
+        format (e.g. list, dictionary, etc.), for as long as the assigned
+        :py:class:`RawDataLoader` can properly handle it.
+
+        .. tip::
+
+           To check that the split and the loader function work correctly, you
+           may use :py:func:`..dataset.check_database_split_loading`.
+
+        This class expects at least one entry called ``train`` to exist in the
+        input dictionary.  Optional entries are ``validation``, and ``test``.
+        Entries named ``monitor-...`` will be considered extra datasets that do
+        not influence any early stop criteria during training, and are just
+        monitored beyond the ``validation`` dataset.

     cache_samples
         If set, then issue raw data loading during ``prepare_data()``, and
@@ -486,8 +527,7 @@ class CachingDataModule(lightning.LightningDataModule):
     def __init__(
         self,
-        database_split: DatabaseSplit,
-        raw_data_loader: RawDataLoader,
+        splits: ConcatDatabaseSplit,
         cache_samples: bool = False,
         balance_sampler_by_class: bool = False,
         batch_size: int = 1,
@@ -499,8 +539,7 @@ class CachingDataModule(lightning.LightningDataModule):
         self.set_chunk_size(batch_size, batch_chunk_count)

-        self.database_split = database_split
-        self.raw_data_loader = raw_data_loader
+        self.splits = splits
         self.cache_samples = cache_samples
         self._train_sampler = None
         self.balance_sampler_by_class = balance_sampler_by_class
@@ -581,6 +620,14 @@ class CachingDataModule(lightning.LightningDataModule):
         If set, then modifies the random sampler used during training
         and validation to balance sample picking probability, making
         sampling across classes **and** datasets equitable.
+
+        .. warning::
+
+           This method does **NOT** balance the sampler per dataset, in case
+           multiple datasets compose the same training set.  It only balances
+           samples according to their ground-truth (labels).  If you'd like to
+           have samples balanced per dataset, then implement your own data
+           module inheriting from this one.
         """
         return self._train_sampler is not None
@@ -661,32 +708,45 @@ class CachingDataModule(lightning.LightningDataModule):
                 f"Not re-instantiating it."
             )
             return

+        datasets: list[_CachedDataset | _DelayedLoadingDataset] = []
+
         if self.cache_samples:
             logger.info(
                 f"Loading dataset:`{name}` into memory (caching)."
                 f" Trade-off: CPU RAM: more | Disk: less"
             )
-            self._datasets[name] = _CachedDataset(
-                self.database_split[name],
-                self.raw_data_loader,
-                self.parallel,
-                self.model_transforms,
-            )
+            for split, loader in self.splits[name]:
+                datasets.append(
+                    _CachedDataset(
+                        split,
+                        loader,
+                        self.parallel,
+                        self.model_transforms,
+                    )
+                )
         else:
             logger.info(
                 f"Loading dataset:`{name}` without caching."
                 f" Trade-off: CPU RAM: less | Disk: more"
             )
-            self._datasets[name] = _DelayedLoadingDataset(
-                self.database_split[name],
-                self.raw_data_loader,
-                self.model_transforms,
-            )
+            for split, loader in self.splits[name]:
+                datasets.append(
+                    _DelayedLoadingDataset(
+                        split,
+                        loader,
+                        self.model_transforms,
+                    )
+                )
+
+        if len(datasets) == 1:
+            self._datasets[name] = datasets[0]
+        else:
+            self._datasets[name] = _ConcatDataset(datasets)

     def _val_dataset_keys(self) -> list[str]:
         """Returns list of validation dataset names."""
         return ["validation"] + [
-            k for k in self.database_split.keys() if k.startswith("monitor-")
+            k for k in self.splits.keys() if k.startswith("monitor-")
         ]

     def setup(self, stage: str) -> None:
@@ -727,7 +787,7 @@ class CachingDataModule(lightning.LightningDataModule):
             self._setup_dataset("test")

         elif stage == "predict":
-            for k in self.database_split:
+            for k in self.splits:
                 self._setup_dataset(k)

     def teardown(self, stage: str) -> None:
@@ -826,3 +886,53 @@ class CachingDataModule(lightning.LightningDataModule):
             )
             for k in self._datasets
         }
+
+
+class CachingDataModule(ConcatDataModule):
+    """A simplified version of our data module for a single split.
+
+    Apart from construction, the behaviour of this data module is very similar
+    to its more complex counterpart, serving training, validation and test
+    sets.
+
+    Parameters
+    ----------
+
+    database_split
+        A dictionary that contains string keys representing dataset names, and
+        values that are iterables over sample representations (potentially on
+        disk).  These objects are passed to a unique :py:class:`RawDataLoader`
+        for loading the :py:class:`Sample` data (and metadata) in memory.  It
+        therefore assumes the whole split is homogeneous and can be loaded in
+        the same way.
+
+        .. tip::
+
+           To check that the split and the loader function work correctly, you
+           may use :py:func:`..dataset.check_database_split_loading`.
+
+        This class expects at least one entry called ``train`` to exist in the
+        input dictionary.  Optional entries are ``validation``, and ``test``.
+        Entries named ``monitor-...`` will be considered extra datasets that do
+        not influence any early stop criteria during training, and are just
+        monitored beyond the ``validation`` dataset.
+
+    raw_data_loader
+        An object instance that can load samples and labels from storage.
+
+    **kwargs
+        List of named parameters matching those of
+        :py:class:`ConcatDataModule`, other than ``splits``.
+    """
+
+    def __init__(
+        self,
+        database_split: DatabaseSplit,
+        raw_data_loader: RawDataLoader,
+        **kwargs,
+    ):
+        splits = {k: [(v, raw_data_loader)] for k, v in database_split.items()}
+        super().__init__(
+            splits=splits,
+            **kwargs,
+        )
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Defines most common types used in code."""

import collections.abc

@@ -51,8 +50,23 @@ TransformSequence = typing.Sequence[Transform]

 DatabaseSplit = collections.abc.Mapping[str, typing.Sequence[typing.Any]]
 """The definition of a database split.

-A database split maps subset names to sequences of objects that, through
-RawDataLoader's eventually become Samples in the processing pipeline.
+A database split maps dataset (subset) names to sequences of objects
+that, through :py:class:`RawDataLoader`s, eventually become
+:py:class:`Sample`s in the processing pipeline.
+"""
+
+
+ConcatDatabaseSplit = collections.abc.Mapping[
+    str,
+    typing.Sequence[tuple[typing.Sequence[typing.Any], RawDataLoader]],
+]
+"""The definition of a complex database split composed of several other splits.
+
+A database split maps dataset (subset) names to sequences of objects
+that, through :py:class:`RawDataLoader`s, eventually become
+:py:class:`Sample`s in the processing pipeline.  Objects of this subtype
+allow the construction of complex splits composed of cannibalized parts
+of other splits.  Each split may be assigned a different
+:py:class:`RawDataLoader`.
 """