-
André Anjos authoredAndré Anjos authored
typing.py 2.80 KiB
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Defines most common types used in code."""
import collections.abc
import typing
import torch
import torch.utils.data
# Type alias (not a class): a 2-tuple of (input tensor, metadata mapping with
# string keys).  It is the return type of ``RawDataLoader.sample`` and the
# type parameter of the ``Dataset`` and ``DataLoader`` definitions below.
Sample = tuple[torch.Tensor, typing.Mapping[str, typing.Any]]
"""Definition of a sample.
First parameter
The actual data that is input to the model
Second parameter
A dictionary containing a named set of meta-data. One of the most common
entries is ``label``.
"""
class RawDataLoader:
    """Base interface for objects that load samples and labels from storage.

    Subclasses must override :py:meth:`sample`.  Overriding :py:meth:`label`
    is optional, since its default implementation derives the label from the
    whole sample.
    """

    def sample(self, _: typing.Any) -> Sample:
        """Load a whole sample from media.

        This base implementation is only a placeholder: it ignores its
        argument and always raises.

        Raises
        ------
        NotImplementedError
            Always, until a subclass overrides this method.
        """
        message = "You must implement the `sample()` method"
        raise NotImplementedError(message)

    def label(self, k: typing.Any) -> int:
        """Load only the label of a sample from media.

        Unless overridden, this forwards ``k`` to :py:meth:`sample`, takes
        the second element of the returned sample (its meta-data mapping) and
        returns the value stored under the ``"label"`` key.
        """
        loaded = self.sample(k)
        # Index rather than unpack, so that only element 1 is ever required.
        metadata = loaded[1]
        return metadata["label"]
# Type alias: a callable taking exactly one tensor argument and returning a
# tensor.
Transform = typing.Callable[[torch.Tensor], torch.Tensor]
"""A callable, that transforms tensors into (other) tensors.
Typically used in data-processing pipelines inside pytorch.
"""
# Type alias: any ``typing.Sequence`` whose items are ``Transform`` callables.
TransformSequence = typing.Sequence[Transform]
"""A sequence of transforms."""
# Type alias: a read-only mapping from string keys to sequences whose items
# are left untyped (``typing.Any``).
DatabaseSplit = collections.abc.Mapping[str, typing.Sequence[typing.Any]]
"""The definition of a database split.
A database split maps dataset (subset) names to sequences of objects
that, through a :py:class:`RawDataLoader`, eventually become
:py:class:`Sample` objects in the processing pipeline.
"""
# Type alias: a read-only mapping from string keys to a sequence of
# ``(entries, loader)`` 2-tuples, where ``entries`` is itself a sequence of
# untyped items and ``loader`` is a ``RawDataLoader``.
ConcatDatabaseSplit = collections.abc.Mapping[
str,
typing.Sequence[tuple[typing.Sequence[typing.Any], RawDataLoader]],
]
"""The definition of a complex database split composed of several other splits.
A database split maps dataset (subset) names to sequences of objects
that, through a :py:class:`RawDataLoader`, eventually become
:py:class:`Sample` objects in the processing pipeline. Objects of this subtype
allow the construction of complex splits composed of cannibalized parts
of other splits: each subset name maps to a sequence of
``(objects, loader)`` pairs, so each part may be assigned a different
:py:class:`RawDataLoader`.
"""
class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized):
    """Project-specific flavour of a pytorch Dataset.

    The element type is fixed to ``Sample``.  The class additionally lists
    ``typing.Iterable`` and ``typing.Sized`` as bases, declaring that
    datasets can be iterated over and support ``len()``.
    """

    def labels(self) -> list[int]:
        """Return the integer labels for all samples in the dataset.

        This base implementation is only a placeholder.

        Raises
        ------
        NotImplementedError
            Always, until a subclass overrides this method.
        """
        message = "You must implement the `labels()` method"
        raise NotImplementedError(message)
# Type alias: ``torch.utils.data.DataLoader`` subscripted with ``Sample``.
# No new class is created here.
DataLoader = torch.utils.data.DataLoader[Sample]
"""Our own augmentation definition of a pytorch DataLoader.
We iterate over Sample objects in this case.
"""