From 0827ccca3d3af87bf3e6daecaecdde7bf0164096 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Sun, 9 Jul 2023 22:45:00 +0200
Subject: [PATCH] [ptbench.data.datamodule] Implemented typing, added more
 logging, implemented sampler-balancing depending on sample class prevalence,
 parallelized dataset caching

---
 src/ptbench/data/datamodule.py                | 424 +++++++++++-------
 src/ptbench/data/dataset.py                   |  66 ---
 .../{raw_data_loader.py => image_utils.py}    |  20 +-
 src/ptbench/data/shenzhen/default.py          |  38 +-
 src/ptbench/data/shenzhen/loader.py           | 105 +++++
 src/ptbench/data/shenzhen/raw_data_loader.py  |  67 ---
 src/ptbench/data/split.py                     |  28 +-
 src/ptbench/data/typing.py                    |  75 ++++
 8 files changed, 501 insertions(+), 322 deletions(-)
 delete mode 100644 src/ptbench/data/dataset.py
 rename src/ptbench/data/{raw_data_loader.py => image_utils.py} (87%)
 create mode 100644 src/ptbench/data/shenzhen/loader.py
 delete mode 100644 src/ptbench/data/shenzhen/raw_data_loader.py
 create mode 100644 src/ptbench/data/typing.py

diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index fc9883d2..a7f0a0fe 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -12,8 +12,16 @@ import lightning
 import torch
 import torch.utils.data
 import torchvision.transforms
+import tqdm
 
-from tqdm import tqdm
+from .typing import (
+    DatabaseSplit,
+    DataLoader,
+    Dataset,
+    RawDataLoader,
+    Sample,
+    Transform,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -64,7 +72,7 @@ def _setup_dataloader_multiproc_parameters(
     return retval
 
 
-class _DelayedLoadingDataset(torch.utils.data.Dataset):
+class _DelayedLoadingDataset(Dataset):
     """A list that loads its samples on demand.
 
     This list mimics a pytorch Dataset, except raw data loading is done
@@ -78,11 +86,8 @@ class _DelayedLoadingDataset(torch.utils.data.Dataset):
         An iterable containing the raw dataset samples loaded from the database
         splits.
 
-    raw_data_loader
-        A callable that will take the representation of the samples declared
-        inside database splits, load the actual raw data (if one exists), and
-        transform it into a :py:class:`torch.Tensor`containing floats in the
-        interval ``[0, 1]``, and a dictionary with further metadata attributes.
+    loader
+        An object instance that can load samples and labels from storage.
 
     transforms
         A set of transforms that should be applied on-the-fly for this dataset,
@@ -92,30 +97,30 @@ class _DelayedLoadingDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         split: typing.Sequence[typing.Any],
-        raw_data_loader: typing.Callable[
-            [typing.Any], tuple[torch.Tensor, typing.Mapping]
-        ],
-        transforms: typing.Sequence[
-            typing.Callable[[torch.Tensor], torch.Tensor]
-        ] = [],
+        loader: RawDataLoader,
+        transforms: typing.Sequence[Transform] = [],
     ):
         self.split = split
-        self.raw_data_loader = raw_data_loader
-        # Cannot unpack empty list
-        if len(transforms) > 0:
-            self.transform = torchvision.transforms.Compose([*transforms])
-        else:
-            self.transform = torchvision.transforms.Compose([])
+        self.loader = loader
+        self.transform = torchvision.transforms.Compose(transforms)
+
+    def labels(self) -> list[int]:
+        """Returns the integer labels for all samples in the dataset."""
+        return [self.loader.label(k) for k in self.split]
 
-    def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
-        tensor, metadata = self.raw_data_loader(self.split[key])
+    def __getitem__(self, key: int) -> Sample:
+        tensor, metadata = self.loader.sample(self.split[key])
         return self.transform(tensor), metadata
 
     def __len__(self):
         return len(self.split)
 
+    def __iter__(self):
+        for x in range(len(self)):
+            yield self[x]
 
-class _CachedDataset(torch.utils.data.Dataset):
+
+class _CachedDataset(Dataset):
     """Basically, a list of preloaded samples.
 
     This dataset will load all samples from the split during construction
@@ -130,11 +135,14 @@ class _CachedDataset(torch.utils.data.Dataset):
         An iterable containing the raw dataset samples loaded from the database
         splits.
 
-    raw_data_loader
-        A callable that will take the representation of the samples declared
-        inside database splits, load the actual raw data (if one exists), and
-        transform it into a :py:class:`torch.Tensor`containing floats in the
-        interval ``[0, 1]``, and a dictionary with further metadata attributes.
+    loader
+        An object instance that can load samples and labels from storage.
+
+    parallel
+        Use multiprocessing for data loading: if set to -1 (default), disables
+        multiprocessing data loading.  Set to 0 to enable as many data loading
+        instances as processing cores as available in the system.  Set to >= 1
+        to enable that many multiprocessing instances for data loading.
 
     transforms
         A set of transforms that should be applied to the cached samples for
@@ -145,43 +153,96 @@ class _CachedDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         split: typing.Sequence[typing.Any],
-        raw_data_loader: typing.Callable[
-            [typing.Any], tuple[torch.Tensor, typing.Mapping]
-        ],
-        transforms: typing.Sequence[
-            typing.Callable[[torch.Tensor], torch.Tensor]
-        ] = [],
+        loader: RawDataLoader,
+        parallel: int = -1,
+        transforms: typing.Sequence[Transform] = [],
     ):
-        # Cannot unpack empty list
-        if len(transforms) > 0:
-            self.transform = torchvision.transforms.Compose([*transforms])
-        else:
-            self.transform = torchvision.transforms.Compose([])
-
-        self.data = [raw_data_loader(k) for k in tqdm(split)]
+        self.transform = torchvision.transforms.Compose(transforms)
 
-    def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
+        if parallel < 0:
+            self.data = [
+                loader.sample(k) for k in tqdm.tqdm(split, unit="sample")
+            ]
+        else:
+            instances = parallel or multiprocessing.cpu_count()
+            logger.info(f"Caching dataset using {instances} processes...")
+            with multiprocessing.Pool(instances) as p:
+                self.data = list(
+                    tqdm.tqdm(p.imap(loader.sample, split), total=len(split))
+                )
+
+    def labels(self) -> list[int]:
+        """Returns the integer labels for all samples in the dataset."""
+        return [k[1]["label"] for k in self.data]
+
+    def __getitem__(self, key: int) -> Sample:
         tensor, metadata = self.data[key]
         return self.transform(tensor), metadata
 
     def __len__(self):
         return len(self.data)
 
+    def __iter__(self):
+        for x in range(len(self)):
+            yield self[x]
+
 
-def get_sample_weights(
-    dataset: _DelayedLoadingDataset | _CachedDataset,
-) -> torch.Tensor:
-    """Computes the (inverse) probabilities of samples based on their class.
+def _make_balanced_random_sampler(
+    dataset: Dataset,
+    target: str = "label",
+) -> torch.utils.data.WeightedRandomSampler:
+    """Generates a pytorch sampler that samples according to class
+    probabilities.
 
     This function takes as input a torch Dataset, and computes the weights to
     balance each class in the dataset, and the datasets themselves if one
     passes a :py:class:`torch.utils.data.ConcatDataset`.
 
-    This function assumes labels are stored on the second entry of sample and
-    are integers. If that is not the case, we just balance the overall dataset
-    w.r.t. each other (e.g. with a :py:class:`torch.utils.data.ConcatDataset`).
-    For single dataset (not a concatenated one) without labels this function is
-    the same as a no-op.
+    In this implementation, we balance **both** class and dataset-origin
+    probabilities, what you expect for a truly *equitable* random sampler.
+
+    Take this example for illustration:
+
+    * Dataset 1: N = 10 samples, 9 samples with target=0, 1 sample with target=1
+    * Dataset 2: N = 6 samples, 3 samples with target=0, 3 samples with target=1
+
+    So:
+
+    | Dataset | Target | Samples | Weight | Normalised weight |
+    +---------+--------+---------+--------+-------------------+
+    |    1    |   0    |    9    |  1/9   |      1/36         |
+    |    1    |   1    |    1    |  1/1   |      1/4          |
+    |    2    |   0    |    3    |  1/3   |      1/12         |
+    |    2    |   1    |    3    |  1/3   |      1/12         |
+
+    Legend:
+
+    * Weight: the weights computed by this method
+    * Normalised weight: the weight per sample used by the random sampler,
+      after normalising the weights by the sum of all weights in the
+      concatenated dataset, such that the sum of all normalized weights times
+      the number of samples is 1.
+
+    The properties of this algorithm are as follows:
+
+    1. The probability of picking a sample from any target is the same (0.5 in
+       this case).  To verify this, notice that the probability of picking a
+       sample with ``target=0`` is :math:`1/4 x 1 + 1/12 x 3 = 0.5`.
+    2. The probabiility of picking a sample with ``target=0`` from Dataset 2 is
+       3 times higher than those from Dataset 1.  As there are 3 times less
+       samples in Dataset 2 with ``target=0``, this makes choosing samples from
+       Dataset 1 proportionally less likely.
+    3. The probabiility of picking a sample with ``target=1`` from Dataset 2 is
+       3 times lower than those from Dataset 1.  As there are 3 times less
+       samples in Dataset 1 with ``target=1``, this makes choosing samples from
+       Dataset 2 proportionally less likely.
+
+    This function assumes targets are stored on a dictionary entry named
+    ``target`` inside the metadata information for the :py:type:``Sample``, and
+    that its value is integer.
+
+    We then instantiate a pytorch sampler using the inverse probabilities (the
+    more samples of a class, the less likely it becomes to be sampled.
 
 
     Parameters
@@ -189,42 +250,83 @@ def get_sample_weights(
 
     dataset
         An instance of torch Dataset.
-        :py:class:`torch.utils.data.ConcatDataset` are supported
+        :py:class:`torch.utils.data.ConcatDataset` are supported.
+
+    target
+        The name of a metadata key pointing to an integer property that allows
+        balancing the dataset.
 
 
     Returns
     -------
 
-    sample_weights
-        The weights for all the samples in the dataset given as input
+    sampler
+        A sampler, to be used in a dataloader equipped with the same dataset
+        used to calculate the relative sample weights.
+
+
+    Raises
+    ------
+
+    RuntimeError
+        If requested to balance a dataset (single, not-concatenated) without an
+        existing target.
     """
-    retval = []
-
-    def _calculate_dataset_weights(
-        d: _DelayedLoadingDataset | _CachedDataset,
-    ) -> torch.Tensor:
-        """Calculates the weights for one dataset."""
-        if "label" in d[0][1] and isinstance(d[0][1]["label"], int):
-            # there are labels
-            targets = [k[1]["label"] for k in d]
-            counts = collections.Counter(targets)
-            weights = {k: 1.0 / v for k, v in counts.items()}
-            weight_per_sample = [weights[k[1]] for k in d]
-            return torch.tensor(weight_per_sample)
-        else:
-            # no labels, weight dataset samples only by count
-            # n.b.: only works if __len__(d) is implemented, we turn off typechecking
-            weight = 1.0 / len(d)
-            return torch.tensor(len(d) * [weight])
+
+    def _calculate_weights(targets: list[int]) -> list[float]:
+        counts = collections.Counter(targets)
+        weights = {k: 1.0 / v for k, v in counts.items()}
+        return [weights[k] for k in targets]
 
     if isinstance(dataset, torch.utils.data.ConcatDataset):
-        for ds in dataset.datasets:
-            retval.append(_calculate_dataset_weights(ds))  # type: ignore[arg-type]
+        # There are two possible cases: targets/no-targets
+        metadata_example = dataset.datasets[0][0][1]
+        if target in metadata_example and isinstance(
+            metadata_example[target], int
+        ):
+            # there are integer targets, let's balance with those
+            logger.info(
+                f"Balancing sample selection probabilities **and** "
+                f"concatenated-datasets using metadata targets `{target}`"
+            )
+            targets = [
+                k
+                for ds in dataset.datasets
+                for k in typing.cast(Dataset, ds).labels()
+            ]
+            weights = _calculate_weights(targets)
+        else:
+            logger.warning(
+                f"Balancing samples **and** concatenated-datasets "
+                f"WITHOUT metadata targets (`{target}` not available)"
+            )
+            weights = [
+                k
+                for ds in dataset.datasets
+                for k in len(typing.cast(typing.Sized, ds))
+                * [1.0 / len(typing.cast(typing.Sized, ds))]
+            ]
 
-        # Concatenate sample weights from all the datasets
-        return torch.cat(retval)
+            pass
 
-    return _calculate_dataset_weights(dataset)
+    else:
+        metadata_example = dataset[0][1]
+        if target in metadata_example and isinstance(
+            metadata_example[target], int
+        ):
+            logger.info(
+                f"Balancing samples from dataset using metadata "
+                f"targets `{target}`"
+            )
+            weights = _calculate_weights(dataset.labels())
+        else:
+            raise RuntimeError(
+                f"Cannot balance samples without metadata targets `{target}`"
+            )
+
+    return torch.utils.data.WeightedRandomSampler(
+        weights, len(weights), replacement=True
+    )
 
 
 class CachingDataModule(lightning.LightningDataModule):
@@ -253,10 +355,10 @@ class CachingDataModule(lightning.LightningDataModule):
     database_split
         A dictionary that contains string keys representing subset names, and
         values that are iterables over sample representations (potentially on
-        disk).  These objects are passed to the ``raw_data_loader`` for loading
+        disk).  These objects are passed to the ``sample_loader`` for loading
         the sample data (and metadata) in memory.  The objects represented may
         be of any format (e.g. list, dictionary, etc), for as long as the
-        ``raw_data_loader`` can properly handle it.  To check the split and the
+        ``sample_loader`` can properly handle it.  To check the split and the
         loader function works correctly, you may use
         :py:func:`..dataset.check_database_split_loading`.  As is, this class
         expects at least one entry called ``train`` to exist in the input
@@ -265,12 +367,8 @@ class CachingDataModule(lightning.LightningDataModule):
         influence any early stop criteria during training, and are just
         monitored beyond the ``validation`` dataset.
 
-    raw_data_loader
-        A callable that will take the representation of the samples declared
-        inside database splits, load the actual raw data (if one exists), and
-        transform it into a :py:class:`torch.Tensor`containing floats in the
-        interval ``[0, 1]``, and a dictionary with further metadata attributes.
-        Samples can be cached **after** raw data loading.
+    loader
+        An object instance that can load samples and labels from storage.
 
     cache_samples
         If set, then issue raw data loading during ``prepare_data()``, and
@@ -280,25 +378,10 @@ class CachingDataModule(lightning.LightningDataModule):
         this attribute to ``True``.  It is typicall useful for relatively small
         datasets.
 
-    train_sampler
-        If set, then reset the default random sampler from torch with a
-        (potentially) biased one of your choice.  Notice this will not play an
-        effect on the batching strategy, only on the way the samples are picked
-        from the original dataset.  A typical replacement is:
-
-        .. code:: python
-
-           sample_weights = get_sample_weights(dataset)
-           train_sampler = torch.utils.data.WeightedRandomSampler(
-                   sample_weights, len(sample_weights), replacement=True
-                   )
-
-        The variable ``sample_weights`` represents the probabilities (may not
-        sum to one) of picking each particular sample in the dataset for a
-        batch.  We typically set ``replacement=True`` to avoid issues with
-        missing data from one of the classes in mini-batches.  This is
-        particularly important in highly unbalanced datasets.  The function
-        :py:func:`get_sample_weights` may help you in this aspect.
+    balance_sampler_by_class
+        If set, then modifies the random sampler used during training and
+        validation to balance sample picking probability, making sample
+        across classes **and** datasets equitable.
 
     model_transforms
         A list of transforms (torch modules) that will be applied after
@@ -346,17 +429,15 @@ class CachingDataModule(lightning.LightningDataModule):
         to enable that many multiprocessing instances for data loading.
     """
 
+    DatasetDictionary = dict[str, Dataset]
+
     def __init__(
         self,
-        database_split: dict[str, typing.Sequence[typing.Any]],
-        raw_data_loader: typing.Callable[
-            [typing.Any], tuple[torch.Tensor, typing.Mapping]
-        ],
+        database_split: DatabaseSplit,
+        raw_data_loader: RawDataLoader,
         cache_samples: bool = False,
-        train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
-        model_transforms: list[
-            typing.Callable[[torch.Tensor], torch.Tensor]
-        ] = [],
+        balance_sampler_by_class: bool = False,
+        model_transforms: list[Transform] = [],
         batch_size: int = 1,
         batch_chunk_count: int = 1,
         drop_incomplete_batch: bool = False,
@@ -369,7 +450,8 @@ class CachingDataModule(lightning.LightningDataModule):
         self.database_split = database_split
         self.raw_data_loader = raw_data_loader
         self.cache_samples = cache_samples
-        self.train_sampler = train_sampler
+        self._train_sampler = None
+        self.balance_sampler_by_class = balance_sampler_by_class
         self.model_transforms = model_transforms
 
         self.drop_incomplete_batch = drop_incomplete_batch
@@ -380,11 +462,18 @@ class CachingDataModule(lightning.LightningDataModule):
         )  # should only be true if GPU available and using it
 
         # datasets that have been setup() for the current stage
-        self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}  # type: ignore[no-redef]
+        self._datasets: CachingDataModule.DatasetDictionary = {}
 
     @property
     def parallel(self) -> int:
-        """The parallel property."""
+        """Whether to use multiprocessing for data loading.
+
+        Use multiprocessing for data loading: if set to -1 (default),
+        disables multiprocessing data loading.  Set to 0 to enable as
+        many data loading instances as processing cores as available in
+        the system.  Set to >= 1 to enable that many multiprocessing
+        instances for data loading.
+        """
         return self._parallel
 
     @parallel.setter
@@ -394,7 +483,28 @@ class CachingDataModule(lightning.LightningDataModule):
             value
         )
         # datasets that have been setup() for the current stage
-        self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}  # type: ignore[no-redef]
+        self._datasets = {}
+
+    @property
+    def balance_sampler_by_class(self):
+        """Whether to balance samples across labels/datasets.
+
+        If set, then modifies the random sampler used during training
+        and validation to balance sample picking probability, making
+        sample across classes **and** datasets equitable.
+        """
+        return self._train_sampler is not None
+
+    @balance_sampler_by_class.setter
+    def balance_sampler_by_class(self, value: bool):
+        if value:
+            if "train" not in self._datasets:
+                self._setup_dataset("train")
+            self._train_sampler = _make_balanced_random_sampler(
+                self._datasets["train"]
+            )
+        else:
+            self._train_sampler = None
 
     def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
         """Coherently sets the batch-chunk-size after validation.
@@ -450,22 +560,39 @@ class CachingDataModule(lightning.LightningDataModule):
         """
 
         if name in self._datasets:
-            logger.info(f"Dataset {name} is already setup.  Not reloading it.")
+            logger.info(
+                f"Dataset `{name}` is already setup. "
+                f"Not re-instantiating it."
+            )
             return
         if self.cache_samples:
-            logger.info(f"Caching {name} dataset")
+            logger.info(
+                f"Loading dataset:`{name}` into memory (caching)."
+                f" Trade-off: CPU RAM: more | Disk: less"
+            )
             self._datasets[name] = _CachedDataset(
                 self.database_split[name],
                 self.raw_data_loader,
+                self.parallel,
                 self.model_transforms,
             )
         else:
+            logger.info(
+                f"Loading dataset:`{name}` without caching."
+                f" Trade-off: CPU RAM: less | Disk: more"
+            )
             self._datasets[name] = _DelayedLoadingDataset(
                 self.database_split[name],
                 self.raw_data_loader,
                 self.model_transforms,
             )
 
+    def _val_dataset_keys(self) -> list[str]:
+        """Returns list of validation dataset names."""
+        return ["validation"] + [
+            k for k in self.database_split.keys() if k.startswith("monitor-")
+        ]
+
     def setup(self, stage: str) -> None:
         """Sets up datasets for different tasks on the pipeline.
 
@@ -493,17 +620,12 @@ class CachingDataModule(lightning.LightningDataModule):
         """
 
         if stage == "fit":
-            self._setup_dataset("train")
-            self._setup_dataset("validation")
-            for k in self.database_split:
-                if k.startswith("monitor-"):
-                    self._setup_dataset(k)
+            for k in ["train"] + self._val_dataset_keys():
+                self._setup_dataset(k)
 
         elif stage == "validate":
-            self._setup_dataset("validation")
-            for k in self.database_split:
-                if k.startswith("monitor-"):
-                    self._setup_dataset(k)
+            for k in self._val_dataset_keys():
+                self._setup_dataset(k)
 
         elif stage == "test":
             self._setup_dataset("test")
@@ -535,32 +657,33 @@ class CachingDataModule(lightning.LightningDataModule):
             * ``predict``: uses only the test dataset
         """
 
-        self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}  # type: ignore[no-redef]
+        self._datasets = {}
 
-    def train_dataloader(self) -> torch.utils.data.DataLoader:
+    def train_dataloader(self) -> DataLoader:
         """Returns the train data loader."""
 
         return torch.utils.data.DataLoader(
             self._datasets["train"],
-            shuffle=True,
+            shuffle=(self._train_sampler is None),
             batch_size=self._chunk_size,
             drop_last=self.drop_incomplete_batch,
             pin_memory=self.pin_memory,
-            sampler=self.train_sampler,
+            sampler=self._train_sampler,
             **self._dataloader_multiproc,
         )
 
-    def available_dataset_keys(self) -> typing.KeysView[str]:
-        """Returns all names for datasets that are setup."""
-        return self._datasets.keys()
+    def unshuffled_train_dataloader(self) -> DataLoader:
+        """Returns the train data loader without shuffling."""
 
-    def val_database_split_keys(self) -> list[str]:
-        """Returns list of validation dataset names."""
-        return ["validation"] + [
-            k for k in self.database_split.keys() if k.startswith("monitor-")
-        ]
+        return torch.utils.data.DataLoader(
+            self._datasets["train"],
+            shuffle=False,
+            batch_size=self._chunk_size,
+            drop_last=False,
+            **self._dataloader_multiproc,
+        )
 
-    def val_dataloader(self) -> dict[str, torch.utils.data.DataLoader]:
+    def val_dataloader(self) -> dict[str, DataLoader]:
         """Returns the validation data loader(s)"""
 
         validation_loader_opts = {
@@ -571,27 +694,28 @@ class CachingDataModule(lightning.LightningDataModule):
         }
         validation_loader_opts.update(self._dataloader_multiproc)
 
-        # select all keys of interest
         return {
             k: torch.utils.data.DataLoader(
                 self._datasets[k], **validation_loader_opts
             )
-            for k in self.val_database_split_keys()
+            for k in self._val_dataset_keys()
         }
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> dict[str, DataLoader]:
         """Returns the test data loader(s)"""
 
-        return torch.utils.data.DataLoader(
-            self._datasets["test"],
-            batch_size=self._chunk_size,
-            shuffle=False,
-            drop_last=self.drop_incomplete_batch,
-            pin_memory=self.pin_memory,
-            **self._dataloader_multiproc,
+        return dict(
+            test=torch.utils.data.DataLoader(
+                self._datasets["test"],
+                batch_size=self._chunk_size,
+                shuffle=False,
+                drop_last=self.drop_incomplete_batch,
+                pin_memory=self.pin_memory,
+                **self._dataloader_multiproc,
+            )
         )
 
-    def predict_dataloader(self):
+    def predict_dataloader(self) -> dict[str, DataLoader]:
         """Returns the prediction data loader(s)"""
 
         return self.test_dataloader()
diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py
deleted file mode 100644
index 3529ee7f..00000000
--- a/src/ptbench/data/dataset.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-import logging
-
-import torch
-import torch.utils.data
-
-logger = logging.getLogger(__name__)
-
-
-def _get_positive_weights(dataloader):
-    """Compute the positive weights of each class of the dataset to balance the
-    BCEWithLogitsLoss criterion.
-
-    This function takes as input a :py:class:`torch.utils.data.DataLoader`
-    and computes the positive weights of each class to use them to have
-    a balanced loss.
-
-
-    Parameters
-    ----------
-
-    dataloader : :py:class:`torch.utils.data.DataLoader`
-        A DataLoader from which to compute the positive weights. Must contain a 'label' key in the metadata returned by __getitem__().
-
-
-    Returns
-    -------
-
-    positive_weights : :py:class:`torch.Tensor`
-        the positive weight of each class in the dataset given as input
-    """
-    targets = []
-
-    for batch in dataloader:
-        targets.extend(batch[1]["label"])
-
-    targets = torch.tensor(targets)
-
-    # Binary labels
-    if len(list(targets.shape)) == 1:
-        class_sample_count = [
-            float((targets == t).sum().item())
-            for t in torch.unique(targets, sorted=True)
-        ]
-
-        # Divide negatives by positives
-        positive_weights = torch.tensor(
-            [class_sample_count[0] / class_sample_count[1]]
-        ).reshape(-1)
-
-    # Multiclass labels
-    else:
-        class_sample_count = torch.sum(targets, dim=0)
-        negative_class_sample_count = (
-            torch.full((targets.size()[1],), float(targets.size()[0]))
-            - class_sample_count
-        )
-
-        positive_weights = negative_class_sample_count / (
-            class_sample_count + negative_class_sample_count
-        )
-
-    return positive_weights
diff --git a/src/ptbench/data/raw_data_loader.py b/src/ptbench/data/image_utils.py
similarity index 87%
rename from src/ptbench/data/raw_data_loader.py
rename to src/ptbench/data/image_utils.py
index d852743f..ac31b9ce 100644
--- a/src/ptbench/data/raw_data_loader.py
+++ b/src/ptbench/data/image_utils.py
@@ -5,6 +5,8 @@
 
 """Data loading code."""
 
+import pathlib
+
 import numpy
 import PIL.Image
 
@@ -41,58 +43,58 @@ class RemoveBlackBorders:
         return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
 
 
-def load_pil(path):
+def load_pil(path: str | pathlib.Path) -> PIL.Image.Image:
     """Loads a sample data.
 
     Parameters
     ----------
 
-    path : str
+    path
         The full path leading to the image to be loaded
 
 
     Returns
     -------
 
-    image : PIL.Image.Image
+    image
         A PIL image
     """
     return PIL.Image.open(path)
 
 
-def load_pil_baw(path):
+def load_pil_baw(path: str | pathlib.Path) -> PIL.Image.Image:
     """Loads a sample data.
 
     Parameters
     ----------
 
-    path : str
+    path
         The full path leading to the image to be loaded
 
 
     Returns
     -------
 
-    image : PIL.Image.Image
+    image
         A PIL image in grayscale mode
     """
     return load_pil(path).convert("L")
 
 
-def load_pil_rgb(path):
+def load_pil_rgb(path: str | pathlib.Path) -> PIL.Image.Image:
     """Loads a sample data.
 
     Parameters
     ----------
 
-    path : str
+    path
         The full path leading to the image to be loaded
 
 
     Returns
     -------
 
-    image : PIL.Image.Image
+    image
         A PIL image in RGB mode
     """
     return load_pil(path).convert("RGB")
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py
index 793c3d41..bfe93f44 100644
--- a/src/ptbench/data/shenzhen/default.py
+++ b/src/ptbench/data/shenzhen/default.py
@@ -4,19 +4,38 @@
 
 """Shenzhen datamodule for computer-aided diagnosis (default protocol)
 
-See :py:mod:`ptbench.data.shenzhen` for dataset details.
+See :py:mod:`ptbench.data.shenzhen` for more database details.
 
 This configuration:
-* raw data (default): :py:obj:`ptbench.data.shenzhen._tranforms`
-* augmentations: elastic deformation (probability = 80%)
-* output image resolution: 512x512 pixels
+
+* Raw data input (on disk):
+
+  * PNG images (black and white, encoded as color images)
+  * Variable width and height:
+
+    * widths: from 1130 to 3001 pixels
+    * heights: from 948 to 3001 pixels
+
+* Output image:
+
+  * Transforms:
+
+    * Load raw PNG with :py:mod:`PIL`
+    * Remove black borders
+    * Torch resizing(512px, 512px)
+    * Torch center cropping (512px, 512px)
+
+  * Final specifications:
+
+    * Fixed resolution: 512x512 pixels
+    * Color RGB encoding
 """
 
 import importlib.resources
 
 from ..datamodule import CachingDataModule
 from ..split import JSONDatabaseSplit
-from .raw_data_loader import raw_data_loader
+from .loader import RawDataLoader
 
 datamodule = CachingDataModule(
     database_split=JSONDatabaseSplit(
@@ -24,12 +43,5 @@ datamodule = CachingDataModule(
             "default.json.bz2"
         )
     ),
-    raw_data_loader=raw_data_loader,
-    cache_samples=False,
-    # train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
-    # model_transforms = [],
-    # batch_size = 1,
-    # batch_chunk_count = 1,
-    # drop_incomplete_batch = False,
-    # parallel = -1,
+    raw_data_loader=RawDataLoader(),
 )
diff --git a/src/ptbench/data/shenzhen/loader.py b/src/ptbench/data/shenzhen/loader.py
new file mode 100644
index 00000000..e983fe70
--- /dev/null
+++ b/src/ptbench/data/shenzhen/loader.py
@@ -0,0 +1,105 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Shenzhen dataset for computer-aided diagnosis.
+
+The standard digital image database for Tuberculosis is created by the National
+Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s
+Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from
+out-patient clinics, and were captured as part of the daily routine using
+Philips DR Digital Diagnose systems.
+
+* Reference: [MONTGOMERY-SHENZHEN-2014]_
+* Original resolution (height x width or width x height): 3000 x 3000 or less
+* Split reference: none
+* Protocol ``default``:
+
+  * Training samples: 64% of TB and healthy CXR (including labels)
+  * Validation samples: 16% of TB and healthy CXR (including labels)
+  * Test samples: 20% of TB and healthy CXR (including labels)
+"""
+
+import os
+
+import torchvision.transforms
+
+from ...utils.rc import load_rc
+from ..image_utils import RemoveBlackBorders, load_pil_baw
+from ..typing import RawDataLoader as _BaseRawDataLoader
+from ..typing import Sample
+
+
+class RawDataLoader(_BaseRawDataLoader):
+    """A specialized raw-data-loader for the Shenzen dataset.
+
+    Attributes
+    ----------
+
+    datadir
+        This variable contains the base directory where the database raw data
+        is stored.
+
+    transform
+        Transforms that are always applied to the loaded raw images.
+    """
+
+    datadir: str
+    transform: torchvision.transforms.Compose
+
+    def __init__(self):
+        self.datadir = load_rc().get(
+            "datadir.shenzhen", os.path.realpath(os.curdir)
+        )
+
+        self.transform = torchvision.transforms.Compose(
+            [
+                RemoveBlackBorders(),
+                torchvision.transforms.Resize(512),
+                torchvision.transforms.CenterCrop(512),
+                torchvision.transforms.ToTensor(),
+            ]
+        )
+
+    def sample(self, sample: tuple[str, int]) -> Sample:
+        """Loads a single image sample from the disk.
+
+        Parameters
+        ----------
+
+        sample:
+            A tuple containing the path suffix, within the dataset root folder,
+            where to find the image to be loaded, and an integer, representing the
+            sample label.
+
+
+        Returns
+        -------
+
+        sample
+            The sample representation
+        """
+        tensor = self.transform(
+            load_pil_baw(os.path.join(self.datadir, sample[0]))
+        )
+        return tensor, dict(label=sample[1])  # type: ignore[arg-type]
+
+    def label(self, sample: tuple[str, int]) -> int:
+        """Loads a single image sample label from the disk.
+
+        Parameters
+        ----------
+
+        sample:
+            A tuple containing the path suffix, within the dataset root folder,
+            where to find the image to be loaded, and an integer, representing the
+            sample label.
+
+
+        Returns
+        -------
+
+        label
+            The integer label associated with the sample
+        """
+        return sample[1]
diff --git a/src/ptbench/data/shenzhen/raw_data_loader.py b/src/ptbench/data/shenzhen/raw_data_loader.py
deleted file mode 100644
index 0c2c39df..00000000
--- a/src/ptbench/data/shenzhen/raw_data_loader.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Shenzhen dataset for computer-aided diagnosis.
-
-The standard digital image database for Tuberculosis is created by the National
-Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s
-Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from
-out-patient clinics, and were captured as part of the daily routine using
-Philips DR Digital Diagnose systems.
-
-* Reference: [MONTGOMERY-SHENZHEN-2014]_
-* Original resolution (height x width or width x height): 3000 x 3000 or less
-* Split reference: none
-* Protocol ``default``:
-
-  * Training samples: 64% of TB and healthy CXR (including labels)
-  * Validation samples: 16% of TB and healthy CXR (including labels)
-  * Test samples: 20% of TB and healthy CXR (including labels)
-"""
-
-import os
-import typing
-
-import torch.nn
-import torchvision.transforms
-
-from ...utils.rc import load_rc
-from ..raw_data_loader import RemoveBlackBorders, load_pil_baw
-
-_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
-"""This variable contains the base directory where the database raw data is
-stored."""
-
-_transform = torchvision.transforms.Compose(
-    [
-        RemoveBlackBorders(),
-        torchvision.transforms.Resize(512),
-        torchvision.transforms.CenterCrop(512),
-        torchvision.transforms.ToTensor(),
-    ]
-)
-"""Transforms that are always applied to the loaded raw images."""
-
-
-def raw_data_loader(
-    sample: tuple[str, int]
-) -> tuple[torch.Tensor, typing.Mapping]:
-    """Loads a single image sample from the disk.
-
-    Parameters
-    ----------
-
-    img_path
-        The path suffix, within the dataset root folder, where to find the
-        image to be loaded.
-
-
-    Returns
-    -------
-
-    image
-        A PIL image in grayscale mode
-    """
-    tensor = _transform(load_pil_baw(os.path.join(_datadir, sample[0])))
-    return tensor, dict(label=sample[1])  # type: ignore[arg-type]
diff --git a/src/ptbench/data/split.py b/src/ptbench/data/split.py
index 06e813ee..78fe7e33 100644
--- a/src/ptbench/data/split.py
+++ b/src/ptbench/data/split.py
@@ -2,7 +2,6 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-import collections.abc
 import csv
 import importlib.abc
 import json
@@ -12,13 +11,12 @@ import typing
 
 import torch
 
+from .typing import DatabaseSplit, RawDataLoader
+
 logger = logging.getLogger(__name__)
 
 
-class JSONDatabaseSplit(
-    dict,
-    typing.Mapping[str, typing.Sequence[typing.Any]],
-):
+class JSONDatabaseSplit(DatabaseSplit):
     """Defines a loader that understands a database split (train, test, etc) in
     JSON format.
 
@@ -73,7 +71,7 @@ class JSONDatabaseSplit(
         self.path = path
         self.subsets = self._load_split_from_disk()
 
-    def _load_split_from_disk(self) -> dict[str, typing.Sequence[typing.Any]]:
+    def _load_split_from_disk(self) -> DatabaseSplit:
         """Loads all subsets in a split from its file system representation.
 
         This method will load JSON information for the current split and return
@@ -109,7 +107,7 @@ class JSONDatabaseSplit(
         return len(self.subsets)
 
 
-class CSVDatabaseSplit(collections.abc.Mapping):
+class CSVDatabaseSplit(DatabaseSplit):
     """Defines a loader that understands a database split (train, test, etc) in
     CSV format.
 
@@ -154,9 +152,7 @@ class CSVDatabaseSplit(collections.abc.Mapping):
         self.directory = directory
         self.subsets = self._load_split_from_disk()
 
-    def _load_split_from_disk(
-        self,
-    ) -> dict[str, list[typing.Any]]:
+    def _load_split_from_disk(self) -> DatabaseSplit:
         """Loads all subsets in a split from its file system representation.
 
         This method will load CSV information for the current split and return all
@@ -171,7 +167,7 @@ class CSVDatabaseSplit(collections.abc.Mapping):
             A dictionary mapping subset names to lists of JSON objects
         """
 
-        retval = {}
+        retval: DatabaseSplit = {}
         for subset in self.directory.iterdir():
             if str(subset).endswith(".csv.bz2"):
                 logger.debug(f"Loading database split from {subset}...")
@@ -204,8 +200,8 @@ class CSVDatabaseSplit(collections.abc.Mapping):
 
 
 def check_database_split_loading(
-    database_split: typing.Mapping[str, typing.Sequence[typing.Any]],
-    loader: typing.Callable[[typing.Any], torch.Tensor],
+    database_split: DatabaseSplit,
+    loader: RawDataLoader,
     limit: int = 0,
 ) -> int:
     """For each subset in the split, check if all data can be correctly loaded
@@ -224,9 +220,7 @@ def check_database_split_loading(
         object that represents a single sample.
 
     loader
-        A callable that transforms sample entries in the database split
-        into :py:class:`torch.Tensor` objects that can be used for training
-        or inference.
+        A loader object that knows how to handle full-samples or just labels.
 
     limit
         Maximum number of samples to check (in each split/subset
@@ -248,7 +242,7 @@ def check_database_split_loading(
         samples = subset if not limit else subset[:limit]
         for pos, sample in enumerate(samples):
             try:
-                data = loader(sample)
+                data, _ = loader.sample(sample)
                 assert isinstance(data, torch.Tensor)
             except Exception as e:
                 logger.info(
diff --git a/src/ptbench/data/typing.py b/src/ptbench/data/typing.py
new file mode 100644
index 00000000..344c1294
--- /dev/null
+++ b/src/ptbench/data/typing.py
@@ -0,0 +1,75 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Defines most common types used in code."""
+
+import typing
+
+import torch
+import torch.utils.data
+
+Sample = tuple[torch.Tensor, typing.Mapping[str, typing.Any]]
+"""Definition of a sample.
+
+First parameter
+    The actual data that is input to the model
+
+Second parameter
+    A dictionary containing a named set of meta-data.  One the most common is
+    the ``label`` entry.
+"""
+
+
+class RawDataLoader:
+    """A loader object can load samples and labels from storage."""
+
+    def sample(self, _: typing.Any) -> Sample:
+        """Loads whole samples from media."""
+        raise NotImplementedError("You must implement the `sample()` method")
+
+    def label(self, k: typing.Any) -> int:
+        """Loads only sample label from media.
+
+        If you do not override this implementation, then, by default,
+        this method will call :py:meth:`sample` to load the whole sample
+        and extract the label.
+        """
+        return self.sample(k)[1]["label"]
+
+
+Transform = typing.Callable[[torch.Tensor], torch.Tensor]
+"""A callable, that transforms tensors into (other) tensors.
+
+Typically used in data-processing pipelines inside pytorch.
+"""
+
+TransformSequence = typing.Sequence[Transform]
+"""A sequence of transforms."""
+
+DatabaseSplit = dict[str, typing.Sequence[typing.Any]]
+"""The definition of a database script.
+
+A database script maps subset names to sequences of objects that,
+through RawDataLoader's eventually become Samples in the processing
+pipeline.
+"""
+
+
+class Dataset(torch.utils.data.Dataset[Sample], typing.Iterable, typing.Sized):
+    """Our own definition of a pytorch Dataset, with interesting properties.
+
+    We iterate over Sample objects in this case.  Our datasets always
+    provide a dunder len method.
+    """
+
+    def labels(self) -> list[int]:
+        """Returns the integer labels for all samples in the dataset."""
+        raise NotImplementedError("You must implement the `labels()` method")
+
+
+DataLoader = torch.utils.data.DataLoader[Sample]
+"""Our own augmentation definition of a pytorch DataLoader.
+
+We iterate over Sample objects in this case.
+"""
-- 
GitLab