Commit eea2f306 authored by André Anjos

[data.datamodule] Implements ConcatDataModule (closes #16); Streamline types (see #24)

parent 6f2383c8
Merge request !6: Making use of LightningDataModule and simplification of data loading
@@ -4,6 +4,7 @@
import collections
import functools
import itertools
import logging
import multiprocessing
import sys
@@ -17,12 +18,14 @@ import torchvision.transforms
import tqdm
from .typing import (
ConcatDatabaseSplit,
DatabaseSplit,
DataLoader,
Dataset,
RawDataLoader,
Sample,
Transform,
TransformSequence,
)
logger = logging.getLogger(__name__)
@@ -72,9 +75,9 @@ class _DelayedLoadingDataset(Dataset):
Parameters
----------
split
An iterable containing the raw dataset samples loaded from the database
splits.
raw_dataset
An iterable containing the raw dataset samples representing one of the
database split datasets.
loader
An object instance that can load samples and labels from storage.
@@ -86,11 +89,11 @@ class _DelayedLoadingDataset(Dataset):
def __init__(
self,
split: typing.Sequence[typing.Any],
raw_dataset: typing.Sequence[typing.Any],
loader: RawDataLoader,
transforms: typing.Sequence[Transform] = [],
transforms: TransformSequence = [],
):
self.split = split
self.raw_dataset = raw_dataset
self.loader = loader
self.transform = torchvision.transforms.Compose(transforms)
@@ -105,14 +108,14 @@ class _DelayedLoadingDataset(Dataset):
def labels(self) -> list[int]:
"""Returns the integer labels for all samples in the dataset."""
return [self.loader.label(k) for k in self.split]
return [self.loader.label(k) for k in self.raw_dataset]
def __getitem__(self, key: int) -> Sample:
tensor, metadata = self.loader.sample(self.split[key])
tensor, metadata = self.loader.sample(self.raw_dataset[key])
return self.transform(tensor), metadata
def __len__(self):
return len(self.split)
return len(self.raw_dataset)
def __iter__(self):
for x in range(len(self)):
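For orientation, here is a minimal sketch of how this class is driven, using a hypothetical loader whose ``sample()`` and ``label()`` methods mirror the calls made above (everything below is illustrative and not part of this commit):

import torch

class _ToyLoader:
    # Hypothetical stand-in for a RawDataLoader implementation.
    def sample(self, item):
        path, label = item
        # A real loader would read `path` from disk here.
        return torch.zeros(3, 32, 32), {"path": path, "label": label}

    def label(self, item) -> int:
        return item[1]

raw_dataset = [("a.png", 0), ("b.png", 1)]
dataset = _DelayedLoadingDataset(raw_dataset, _ToyLoader())
tensor, metadata = dataset[0]  # raw data is only read at indexing time
assert dataset.labels() == [0, 1]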
@@ -131,7 +134,7 @@ def _apply_loader_and_transforms(
----------
info
The sample information, as loaded from its split dictionary
The sample information, as loaded from its raw dataset dictionary
load
The raw-data loader function to use for loading the sample
@@ -155,7 +158,7 @@
class _CachedDataset(Dataset):
"""Basically, a list of preloaded samples.
This dataset will load all samples from the split during construction
This dataset will load all samples from the raw dataset during construction
instead of delaying that to the indexing. Beyond raw-data-loading,
``transforms`` given upon construction contribute to the cached samples.
@@ -163,9 +166,9 @@ class _CachedDataset(Dataset):
Parameters
----------
split
An iterable containing the raw dataset samples loaded from the database
splits.
raw_dataset
An iterable containing the raw dataset samples representing one of the
database split datasets.
loader
An object instance that can load samples and labels from storage.
@@ -184,10 +187,10 @@ class _CachedDataset(Dataset):
def __init__(
self,
split: typing.Sequence[typing.Any],
raw_dataset: typing.Sequence[typing.Any],
loader: RawDataLoader,
parallel: int = -1,
transforms: typing.Sequence[Transform] = [],
transforms: TransformSequence = [],
):
self.loader = functools.partial(
_apply_loader_and_transforms,
@@ -197,14 +200,16 @@ class _CachedDataset(Dataset):
if parallel < 0:
self.data = [
self.loader(k) for k in tqdm.tqdm(split, unit="sample")
self.loader(k) for k in tqdm.tqdm(raw_dataset, unit="sample")
]
else:
instances = parallel or multiprocessing.cpu_count()
logger.info(f"Caching dataset using {instances} processes...")
with multiprocessing.Pool(instances) as p:
self.data = list(
tqdm.tqdm(p.imap(self.loader, split), total=len(split))
tqdm.tqdm(
p.imap(self.loader, raw_dataset), total=len(raw_dataset)
)
)
# Estimates memory occupancy
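The ``parallel`` flag above therefore selects the caching strategy; summarising the branch logic just shown (a comment sketch, not code from this commit):

# parallel < 0   -> load samples serially, in the current process
# parallel == 0  -> multiprocessing.Pool with cpu_count() workers
# parallel > 0   -> multiprocessing.Pool with exactly `parallel` workers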
@@ -229,8 +234,42 @@ class _CachedDataset(Dataset):
return len(self.data)
def __iter__(self):
for x in range(len(self)):
yield self[x]
yield from self.data
class _ConcatDataset(Dataset):
"""A dataset that represents a concatenation of other cached or delayed
datasets.
Parameters
----------
datasets
An iterable over pre-instantiated datasets.
"""
def __init__(self, datasets: typing.Sequence[Dataset]):
self._datasets = datasets
self._indices = [
(i, j) # dataset relative position, sample relative position
for i in range(len(datasets))
for j in range(len(datasets[i]))
]
def labels(self) -> list[int]:
"""Returns the integer labels for all samples in the dataset."""
return list(itertools.chain(*[k.labels() for k in self._datasets]))
def __getitem__(self, key: int) -> Sample:
i, j = self._indices[key]
return self._datasets[i][j]
def __len__(self):
return sum([len(k) for k in self._datasets])
def __iter__(self):
for dataset in self._datasets:
yield from dataset
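The ``_indices`` table above flattens the concatenation, so a single integer key addresses samples across dataset boundaries. A worked example with hypothetical sizes:

# With two datasets of lengths 2 and 3, _indices enumerates:
#   [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]
# so dataset[3] resolves to self._datasets[1][1], and len(dataset) == 5.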
def _make_balanced_random_sampler(
@@ -375,14 +414,15 @@
)
class CachingDataModule(lightning.LightningDataModule):
"""A conveninent data module with CSV or JSON protocol loading, mini-
batching, parallelisation and caching, all in one.
class ConcatDataModule(lightning.LightningDataModule):
"""A conveninent data module with dictionary split loading, mini- batching,
parallelisation and caching, all in one.
Instances of this class load data-split (a.k.a. protocol) definitions for a
database, and can load the data from the disk. An optional caching
mechanism stores the data in associated CPU memory, which can improve data
serving while training and evaluating models.
Instances of this class can load and concatenate an arbitrary number of
data-split (a.k.a. protocol) definitions for (possibly disjoint) databases,
and can manage raw data-loading from disk. An optional caching mechanism
stores the data in associated CPU memory, which can improve data serving
while training and evaluating models.
This datamodule defines basic operations to handle data loading and
mini-batch handling within this package's framework. It can return
@@ -390,31 +430,32 @@ class CachingDataModule(lightning.LightningDataModule):
prediction and testing conditions. Parallelisation is handled by a simple
input flag.
Users must implement the basic :py:meth:`setup` function, which is
parameterised by a single string enumeration containing: ``fit``,
``validate``, ``test``, or ``predict``.
Parameters
----------
database_split
A dictionary that contains string keys representing subset names, and
values that are iterables over sample representations (potentially on
disk). These objects are passed to the ``sample_loader`` for loading
the sample data (and metadata) in memory. The objects represented may
be of any format (e.g. list, dictionary, etc.), as long as the
``sample_loader`` can properly handle it. To check that the split and
the loader function work correctly, you may use
:py:func:`..dataset.check_database_split_loading`. As is, this class
expects at least one entry called ``train`` to exist in the input
dictionary. Optional entries are ``validation`` and ``test``. Entries
named ``monitor-...`` will be considered extra subsets that do not
influence any early stop criteria during training, and are just
monitored beyond the ``validation`` dataset.
splits
A dictionary that contains string keys representing dataset names, and
values that are iterables over 2-tuples, each containing an iterable over
arbitrary, user-configurable sample representations (potentially on
disk or permanent storage), and a :py:class:`RawDataLoader` (or "sample"
loader) object, which concretely implements a mechanism to load such
samples into memory from permanent storage.
loader
An object instance that can load samples and labels from storage.
Sample representations on permanent storage may be of any iterable
format (e.g. list, dictionary, etc.), as long as the assigned
:py:class:`RawDataLoader` can properly handle it.
.. tip::
To check that the split and the loader function work correctly, you may
use :py:func:`..dataset.check_database_split_loading`.
This class expects at least one entry called ``train`` to exist in the
input dictionary. Optional entries are ``validation`` and ``test``.
Entries named ``monitor-...`` will be considered extra datasets that do
not influence any early stopping criteria during training, and are just
monitored beyond the ``validation`` dataset.
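To make the ``splits`` structure described above concrete, here is a hypothetical value combining two databases into a single ``train`` dataset (all identifiers are made up; only the shape matters):

splits = {
    "train": [
        (database_a_train_samples, loader_a),
        (database_b_train_samples, loader_b),
    ],
    "validation": [(database_a_validation_samples, loader_a)],
    "test": [(database_a_test_samples, loader_a)],
}
datamodule = ConcatDataModule(splits=splits, batch_size=16)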
cache_samples
If set, then issue raw data loading during ``prepare_data()``, and
@@ -486,8 +527,7 @@ class CachingDataModule(lightning.LightningDataModule):
def __init__(
self,
database_split: DatabaseSplit,
raw_data_loader: RawDataLoader,
splits: ConcatDatabaseSplit,
cache_samples: bool = False,
balance_sampler_by_class: bool = False,
batch_size: int = 1,
@@ -499,8 +539,7 @@ class CachingDataModule(lightning.LightningDataModule):
self.set_chunk_size(batch_size, batch_chunk_count)
self.database_split = database_split
self.raw_data_loader = raw_data_loader
self.splits = splits
self.cache_samples = cache_samples
self._train_sampler = None
self.balance_sampler_by_class = balance_sampler_by_class
@@ -581,6 +620,14 @@ class CachingDataModule(lightning.LightningDataModule):
If set, then modifies the random sampler used during training
and validation to balance sample picking probability, making
sampling across classes **and** datasets equitable.
.. warning::
This method does **NOT** balance the sampler per dataset, in case
multiple datasets compose the same training set. It only balances
samples according to their ground-truth (labels). If you'd like to
have samples balanced per dataset, then implement your own data
module inheriting from this one.
"""
return self._train_sampler is not None
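For reference, label-only balancing of the kind described in the warning can be sketched with :py:class:`torch.utils.data.WeightedRandomSampler`; the body of ``_make_balanced_random_sampler`` is not shown in this diff, so the following is an assumption-laden illustration, not the actual implementation:

import collections

import torch.utils.data

def _sketch_balanced_sampler(labels: list[int]) -> torch.utils.data.WeightedRandomSampler:
    # Weight each sample inversely to its class frequency so that all
    # classes are drawn with (approximately) equal probability.
    counts = collections.Counter(labels)
    weights = [1.0 / counts[label] for label in labels]
    return torch.utils.data.WeightedRandomSampler(
        weights, len(weights), replacement=True
    )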
@@ -661,32 +708,45 @@ class CachingDataModule(lightning.LightningDataModule):
f"Not re-instantiating it."
)
return
datasets: list[_CachedDataset | _DelayedLoadingDataset] = []
if self.cache_samples:
logger.info(
f"Loading dataset:`{name}` into memory (caching)."
f" Trade-off: CPU RAM: more | Disk: less"
)
self._datasets[name] = _CachedDataset(
self.database_split[name],
self.raw_data_loader,
self.parallel,
self.model_transforms,
)
for split, loader in self.splits[name]:
datasets.append(
_CachedDataset(
split,
loader,
self.parallel,
self.model_transforms,
)
)
else:
logger.info(
f"Loading dataset:`{name}` without caching."
f" Trade-off: CPU RAM: less | Disk: more"
)
self._datasets[name] = _DelayedLoadingDataset(
self.database_split[name],
self.raw_data_loader,
self.model_transforms,
)
for split, loader in self.splits[name]:
datasets.append(
_DelayedLoadingDataset(
split,
loader,
self.model_transforms,
)
)
if len(datasets) == 1:
self._datasets[name] = datasets[0]
else:
self._datasets[name] = _ConcatDataset(datasets)
def _val_dataset_keys(self) -> list[str]:
"""Returns list of validation dataset names."""
return ["validation"] + [
k for k in self.database_split.keys() if k.startswith("monitor-")
k for k in self.splits.keys() if k.startswith("monitor-")
]
def setup(self, stage: str) -> None:
@@ -727,7 +787,7 @@ class CachingDataModule(lightning.LightningDataModule):
self._setup_dataset("test")
elif stage == "predict":
for k in self.database_split:
for k in self.splits:
self._setup_dataset(k)
def teardown(self, stage: str) -> None:
@@ -826,3 +886,53 @@ class CachingDataModule(lightning.LightningDataModule):
)
for k in self._datasets
}
class CachingDataModule(ConcatDataModule):
"""A simplified version of our data module for a single split.
Apart from construction, the behaviour of this data module is very similar
to that of its parent, :py:class:`ConcatDataModule`, serving training,
validation and test sets.
Parameters
----------
database_split
A dictionary that contains string keys representing dataset names, and
values that are iterables over sample representations (potentially on
disk). These objects are passed to a unique :py:class:`RawDataLoader`
for loading the :py:class:`Sample` data (and metadata) in memory. It
therefore assumes the whole split is homogeneous and can be loaded in
the same way.
.. tip::
To check that the split and the loader function work correctly, you may
use :py:func:`..dataset.check_database_split_loading`.
This class expects at least one entry called ``train`` to exist in the
input dictionary. Optional entries are ``validation`` and ``test``.
Entries named ``monitor-...`` will be considered extra datasets that do
not influence any early stop criteria during training, and are just
monitored beyond the ``validation`` dataset.
raw_data_loader
An object instance that can load samples and labels from storage.
**kwargs
List of named parameters matching those of
:py:class:`ConcatDataModule`, other than ``splits``.
"""
def __init__(
self,
database_split: DatabaseSplit,
raw_data_loader: RawDataLoader,
**kwargs,
):
splits = {k: [(v, raw_data_loader)] for k, v in database_split.items()}
super().__init__(
splits=splits,
**kwargs,
)
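The constructor above therefore only reshapes a single-loader split into the concatenated format expected by :py:class:`ConcatDataModule`. Schematically, with hypothetical sample objects ``s1``..``s3``:

# database_split = {"train": [s1, s2], "validation": [s3]}
# is wrapped as:
# splits = {
#     "train": [([s1, s2], raw_data_loader)],
#     "validation": [([s3], raw_data_loader)],
# }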
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Defines most common types used in code."""
import collections.abc
@@ -51,8 +50,23 @@ TransformSequence = typing.Sequence[Transform]
DatabaseSplit = collections.abc.Mapping[str, typing.Sequence[typing.Any]]
"""The definition of a database split.
A database split maps subset names to sequences of objects that, through
RawDataLoaders, eventually become Samples in the processing pipeline.
A database split maps dataset (subset) names to sequences of objects
that, through :py:class:`RawDataLoader`s, eventually become
:py:class:`Sample`s in the processing pipeline.
"""
ConcatDatabaseSplit = collections.abc.Mapping[
str,
typing.Sequence[tuple[typing.Sequence[typing.Any], RawDataLoader]],
]
"""The definition of a complex database split composed of several other splits.
A database split maps dataset (subset) names to sequences of objects
that, through :py:class:`RawDataLoader`s, eventually become
:py:class:`Sample`s in the processing pipeline. Objects of this subtype
allow the construction of complex splits composed of cannibalized parts
of other splits. Each split may be assigned a different
:py:class:`RawDataLoader`.
"""