Skip to content
Snippets Groups Projects
typing.py 2.80 KiB
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Defines most common types used in code."""

import collections.abc
import typing

import torch
import torch.utils.data

# A sample is a 2-tuple: the input tensor first, then a string-keyed mapping
# of meta-data.  ``typing.Mapping`` is the read-only abstract interface, so
# any mapping object (not only ``dict``) satisfies the second slot.
Sample = tuple[torch.Tensor, typing.Mapping[str, typing.Any]]
"""Definition of a sample.

First parameter
    The actual data that is input to the model

Second parameter
    A dictionary containing a named set of meta-data.  One of the most common
    entries is ``label``.
"""


class RawDataLoader:
    """Base interface for objects that fetch samples and labels from storage.

    Subclasses are expected to override :py:meth:`sample`; :py:meth:`label`
    comes with a default implementation built on top of it.
    """

    def sample(self, _: typing.Any) -> Sample:
        """Load a complete sample from media.

        This base implementation always raises
        :py:class:`NotImplementedError`.
        """
        message = "You must implement the `sample()` method"
        raise NotImplementedError(message)

    def label(self, k: typing.Any) -> int:
        """Load only the label of a sample from media.

        Unless overridden, this loads the whole sample through
        :py:meth:`sample` and returns the ``label`` entry of its meta-data
        (the second element of the sample).
        """
        # Index (rather than unpack) so the lookup matches the sample layout
        # regardless of how many elements a subclass happens to return.
        metadata = self.sample(k)[1]
        return metadata["label"]


# Signature of a transform: exactly one tensor in, one tensor out.
Transform = typing.Callable[[torch.Tensor], torch.Tensor]
"""A callable, that transforms tensors into (other) tensors.

Typically used in data-processing pipelines inside pytorch.
"""

# ``typing.Sequence`` (rather than ``list``) accepts any ordered, indexable
# container of transforms, e.g. a tuple.
TransformSequence = typing.Sequence[Transform]
"""A sequence of transforms."""

# Keys are strings; each value is a sequence whose element type is left
# unconstrained (``typing.Any``).
DatabaseSplit = collections.abc.Mapping[str, typing.Sequence[typing.Any]]
"""The definition of a database split.

A database split maps dataset (subset) names to sequences of objects
that, through a :py:class:`RawDataLoader`, eventually become
:py:class:`Sample` objects in the processing pipeline.
"""

# Keys are strings; each value is a sequence of (entries, loader) pairs, so
# every part of a subset carries the RawDataLoader used to load it.
ConcatDatabaseSplit = collections.abc.Mapping[
    str,
    typing.Sequence[tuple[typing.Sequence[typing.Any], RawDataLoader]],
]
"""The definition of a complex database split composed of several other splits.

A database split maps dataset (subset) names to sequences of objects
that, through a :py:class:`RawDataLoader`, eventually become
:py:class:`Sample` objects in the processing pipeline. Objects of this
subtype allow the construction of complex splits composed of cannibalized
parts of other splits.  Each part may be assigned a different
:py:class:`RawDataLoader`.
"""


class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized):
    """A pytorch Dataset that is also declared iterable and sized.

    Iterating over one of our datasets goes through Sample objects, and
    every dataset provides a ``__len__`` method.
    """

    def labels(self) -> list[int]:
        """Return the integer labels for all samples in the dataset.

        This base implementation always raises
        :py:class:`NotImplementedError`.
        """
        message = "You must implement the `labels()` method"
        raise NotImplementedError(message)


# The pytorch DataLoader class, parameterized with our Sample type.
DataLoader = torch.utils.data.DataLoader[Sample]
"""Our own augmentation definition of a pytorch DataLoader.

We iterate over Sample objects in this case.
"""