From 8d11b566a1e6011183f89cbdf3eb7e316cdc6ecd Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Thu, 29 Jun 2023 13:25:36 +0200
Subject: [PATCH] Pair programming with @dcarron

---
 src/ptbench/data/datamodule.py                | 167 +++++++---
 src/ptbench/data/dataset.py                   | 291 ------------------
 src/ptbench/data/padchest/__init__.py         |   2 +-
 .../data/{loader.py => raw_data_loader.py}    |  33 ++
 src/ptbench/data/shenzhen/default.py          |   8 +-
 .../{loader.py => raw_data_loader.py}         |   3 +-
 src/ptbench/data/split.py                     | 258 ++++++++++++++++
 src/ptbench/data/transforms.py                |  37 +--
 src/ptbench/models/densenet.py                |  15 +
 src/ptbench/models/normalizer.py              |  39 ++-
 src/ptbench/models/pasa.py                    |   7 +
 11 files changed, 468 insertions(+), 392 deletions(-)
 rename src/ptbench/data/{loader.py => raw_data_loader.py} (51%)
 rename src/ptbench/data/shenzhen/{loader.py => raw_data_loader.py} (96%)
 create mode 100644 src/ptbench/data/split.py

diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 11e76697..6670a309 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -10,11 +10,7 @@ import typing
 import lightning
 import torch
 import torch.utils.data
-
-from clapper.logging import setup
-
-# TODO: No logging on this module...
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+import torchvision.transforms
 
 
 def _setup_dataloader_multiproc_parameters(
@@ -93,11 +89,13 @@ class _DelayedLoadingDataset(torch.utils.data.Dataset):
         raw_data_loader: typing.Callable[
             [typing.Any], tuple[torch.Tensor, typing.Mapping]
         ],
-        transforms: typing.Optional[typing.Sequence[torch.nn.Module]] = None,
+        transforms: typing.Sequence[
+            typing.Callable[[torch.Tensor], torch.Tensor]
+        ] = [],
     ):
         self.split = split
         self.raw_data_loader = raw_data_loader
-        self.transform = torch.nn.Sequential(*transforms)
+        self.transform = torchvision.transforms.Compose(transforms)
 
     def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
         tensor, metadata = self.raw_data_loader(self.split[key])
@@ -137,10 +135,12 @@ class _CachedDataset(torch.utils.data.Dataset):
         raw_data_loader: typing.Callable[
             [typing.Any], tuple[torch.Tensor, typing.Mapping]
         ],
-        transforms: typing.Optional[typing.Sequence[torch.nn.Module]] = None,
+        transforms: typing.Sequence[
+            typing.Callable[[torch.Tensor], torch.Tensor]
+        ] = [],
     ):
         self.data = [raw_data_loader(k) for k in split]
-        self.transform = torch.nn.Sequential(*transforms)
+        self.transform = torchvision.transforms.Compose(transforms)
 
     def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
         tensor, metadata = self.data[key]
@@ -344,22 +344,21 @@ class CachingDataModule(lightning.LightningDataModule):
         ],
         cache_samples: bool = False,
         train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
-        data_augmentations: list[torch.nn.Module] = [],
-        model_transforms: list[torch.nn.Module] = [],
+        data_augmentations: list[
+            typing.Callable[[torch.Tensor], torch.Tensor]
+        ] = [],
+        model_transforms: list[
+            typing.Callable[[torch.Tensor], torch.Tensor]
+        ] = [],
         batch_size: int = 1,
         batch_chunk_count: int = 1,
         drop_incomplete_batch: bool = False,
         parallel: int = -1,
     ):
-        # validation
-        if batch_size % batch_chunk_count != 0:
-            raise RuntimeError(
-                f"batch_size ({batch_size}) must be divisible by "
-                f"batch_chunk_size ({batch_chunk_count})."
-            )
-
         super().__init__()
 
+        self.set_chunk_size(batch_size, batch_chunk_count)
+
         self.database_split = database_split
         self.raw_data_loader = raw_data_loader
         self.cache_samples = cache_samples
@@ -367,16 +366,8 @@ class CachingDataModule(lightning.LightningDataModule):
         self.data_augmentations = data_augmentations
         self.model_transforms = model_transforms
 
-        self._batch_size = batch_size
-        self._batch_chunk_count = batch_chunk_count
-        self._chunk_size = self._batch_size // self._batch_chunk_count
-
         self.drop_incomplete_batch = drop_incomplete_batch
-        self._parallel = parallel  # immutable, otherwise would need to call
-        # the next function again
-        self._dataloader_multiproc = _setup_dataloader_multiproc_parameters(
-            parallel
-        )
+        self.parallel = parallel  # the setter also configures multiprocessing
 
         self.pin_memory = (
             torch.cuda.is_available()
@@ -385,6 +376,63 @@ class CachingDataModule(lightning.LightningDataModule):
         # datasets that have been setup() for the current stage
         self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}
 
+    @property
+    def parallel(self) -> int:
+        """The parallel property."""
+        return self._parallel
+
+    @parallel.setter
+    def parallel(self, value: int) -> None:
+        self._parallel = value
+        self._dataloader_multiproc = _setup_dataloader_multiproc_parameters(
+            value
+        )
+        # reset datasets, so dataloaders are rebuilt with the new settings
+        self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}
+
+    def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
+        """Coherently sets the batch-chunk-size after validation.
+
+        Parameters
+        ----------
+
+        batch_size
+            Number of samples in every **training** batch (this parameter affects
+            memory requirements for the network).  If the number of samples in the
+            batch is larger than the total number of samples available for
+            training, this value is truncated.  If this number is smaller, then
+            batches of the specified size are created and fed to the network  until
+            there are no more new samples to feed (epoch is finished).  If the
+            total number of training samples is not a multiple of the batch-size,
+            the last batch will be smaller than the first, unless
+            ``drop_incomplete_batch`` is set to ``true``, in which case this batch
+            is not used.
+
+        batch_chunk_count
+            Number of chunks in every batch (this parameter affects memory
+            requirements for the network). The number of samples loaded for every
+            iteration will be ``batch_size/batch_chunk_count``. ``batch_size``
+            needs to be divisible by ``batch_chunk_count``, otherwise an error will
+            be raised. This parameter is used to reduce the number of samples
+            loaded in each iteration, in order to reduce memory usage in exchange
+            for processing time (more iterations). This is especially interesting
+            when running on GPUs with limited RAM. The default of 1 forces the
+            whole batch to be processed at once. Otherwise the batch is broken into
+            batch-chunk-count pieces, and gradients are accumulated to complete
+            each batch.
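+
+        For example (illustrative values only): with ``batch_size=16`` and
+        ``batch_chunk_count=4``, each iteration loads a chunk of
+        ``16 / 4 = 4`` samples, and gradients are accumulated over 4 chunks
+        before each optimizer step, so the effective batch size remains 16.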
+        """
+
+        # validation
+        if batch_size % batch_chunk_count != 0:
+            raise RuntimeError(
+                f"batch_size ({batch_size}) must be divisible by "
+                f"batch_chunk_size ({batch_chunk_count})."
+            )
+
+        self._batch_size = batch_size
+        self._batch_chunk_count = batch_chunk_count
+        self._chunk_size = self._batch_size // self._batch_chunk_count
+
     def setup(self, stage: str) -> None:
         """Sets up datasets for different tasks on the pipeline.
 
@@ -440,11 +488,40 @@ class CachingDataModule(lightning.LightningDataModule):
         elif stage == "predict":
             _setup("test", self.model_transforms)
 
-    def train_dataloader(self):
+    def unaugmented_train_dataloader(self) -> torch.utils.data.DataLoader:
+        """Returns a version of the train dataloader without augmentations.
+
+        Use this dataloader to compute input normalisation factors (e.g. mean
+        and standard deviation, or min-max parameterisations) over the
+        training data, without the influence of data augmentations.
+
+
+        Returns
+        -------
+
+        dataloader
+            The unaugmented train dataloader
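+
+        Example (a sketch only; assumes the default ``(tensor, metadata)``
+        batch layout produced by this datamodule):
+
+        .. code-block:: python
+
+           loader = datamodule.unaugmented_train_dataloader()
+           data = torch.cat([batch[0] for batch in loader])
+           mean = data.mean(dim=[0, 2, 3])
+           std = data.std(dim=[0, 2, 3])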
+        """
+        dataset = _DelayedLoadingDataset(
+            self.database_split["train"],
+            self.raw_data_loader,
+            self.model_transforms,
+        )
+        return torch.utils.data.DataLoader(
+            dataset,
+            shuffle=False,
+            batch_size=self._chunk_size,
+            drop_last=self.drop_incomplete_batch,
+            pin_memory=self.pin_memory,
+            **self._dataloader_multiproc,
+        )
+
+    def train_dataloader(self) -> torch.utils.data.DataLoader:
         """Returns the train data loader."""
 
         return torch.utils.data.DataLoader(
             self._datasets["train"],
+            shuffle=True,
             batch_size=self._chunk_size,
             drop_last=self.drop_incomplete_batch,
             pin_memory=self.pin_memory,
@@ -452,14 +529,19 @@ class CachingDataModule(lightning.LightningDataModule):
             **self._dataloader_multiproc,
         )
 
-    def val_dataloader(self):
-        """Returns the validation data loader(s)"""
+    def available_dataset_keys(self) -> typing.KeysView[str]:
+        """Returns all names for datasets that are setup."""
+        return self._datasets.keys()
 
-        extra_valid = [
+    def val_database_split_keys(self) -> list[str]:
+        """Returns list of validation dataset names."""
+        return ["validation"] + [
             k for k in self.database_split.keys() if k.startswith("monitor-")
         ]
 
-        # TODO: do we really need the train sampler here?
+    def val_dataloader(self) -> dict[str, torch.utils.data.DataLoader]:
+        """Returns the validation data loader(s)"""
+
         validation_loader_opts = {
             "batch_size": self._chunk_size,
             "shuffle": False,
@@ -468,22 +550,13 @@ class CachingDataModule(lightning.LightningDataModule):
         }
         validation_loader_opts.update(self._dataloader_multiproc)
 
-        # TODO: not sure this is the right way to handle multiple validation
-        # loaders, please check and fix
-        if not extra_valid:
-            return torch.utils.data.DataLoader(
-                self._datasets["validation"],
-                **validation_loader_opts,
+        # select all keys of interest
+        return {
+            k: torch.utils.data.DataLoader(
+                self._datasets[k], **validation_loader_opts
             )
-
-        else:
-            return [
-                torch.utils.data.DataLoader(
-                    self._datasets[k],
-                    **validation_loader_opts,
-                )
-                for k in ["validation"] + extra_valid
-            ]
+            for k in self.val_database_split_keys()
+        }
 
     def test_dataloader(self):
         """Returns the test data loader(s)"""
diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py
index 6b373db4..c85fbcb3 100644
--- a/src/ptbench/data/dataset.py
+++ b/src/ptbench/data/dataset.py
@@ -2,13 +2,7 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-import collections.abc
-import csv
-import importlib.abc
-import json
 import logging
-import pathlib
-import typing
 
 import torch
 import torch.utils.data
@@ -16,249 +10,6 @@ import torch.utils.data
 logger = logging.getLogger(__name__)
 
 
-class JSONDatabaseSplit(
-    dict,
-    typing.Mapping[str, typing.Sequence[typing.Any]],
-):
-    """Defines a loader that understands a database split (train, test, etc) in
-    JSON format.
-
-    To create a new database split, you need to provide a JSON formatted
-    dictionary in a file, with contents similar to the following:
-
-    .. code-block:: json
-
-       {
-           "subset1": [
-               [
-                   "sample1-data1",
-                   "sample1-data2",
-                   "sample1-data3",
-               ],
-               [
-                   "sample2-data1",
-                   "sample2-data2",
-                   "sample2-data3",
-               ]
-           ],
-           "subset2": [
-               [
-                   "sample42-data1",
-                   "sample42-data2",
-                   "sample42-data3",
-               ],
-           ]
-       }
-
-    Your database split many contain any number of subsets (dictionary keys).
-    For simplicity, we recommend all sample entries are formatted similarly so
-    that raw-data-loading is simplified.  Use the function
-    :py:func:`check_database_split_loading` to test raw data loading and fine
-    tune the dataset split, or its loading.
-
-    Objects of this class behave like a dictionary in which keys are subset
-    names in the split, and values represent samples data and meta-data.
-
-
-    Parameters
-    ----------
-
-    path
-        Absolute path to a JSON formatted file containing the database split to be
-        recognized by this object.
-    """
-
-    def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable):
-        if isinstance(path, str):
-            path = pathlib.Path(path)
-        self.path = path
-        self.subsets = self._load_split_from_disk()
-
-    def _load_split_from_disk(self) -> dict[str, typing.Sequence[typing.Any]]:
-        """Loads all subsets in a split from its file system representation.
-
-        This method will load JSON information for the current split and return
-        all subsets of the given split after converting each entry through the
-        loader function.
-
-
-        Returns
-        -------
-
-        subsets : dict
-            A dictionary mapping subset names to lists of JSON objects
-        """
-
-        if str(self.path).endswith(".bz2"):
-            logger.debug(f"Loading database split from {str(self.path)}...")
-            with __import__("bz2").open(self.path) as f:
-                return json.load(f)
-        else:
-            with self.path.open() as f:
-                return json.load(f)
-
-    def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
-        """Accesses subset ``key`` from this split."""
-        return self.subsets[key]
-
-    def __iter__(self):
-        """Iterates over the subsets."""
-        return iter(self.subsets)
-
-    def __len__(self) -> int:
-        """How many subsets we currently have."""
-        return len(self.subsets)
-
-
-class CSVDatabaseSplit(collections.abc.Mapping):
-    """Defines a loader that understands a database split (train, test, etc) in
-    CSV format.
-
-    To create a new database split, you need to provide one or more CSV
-    formatted files, each representing a subset of this split, containing the
-    sample data (one per row).  Example:
-
-    Inside the directory ``my-split/``, one can file files ``train.csv``,
-    ``validation.csv``, and ``test.csv``.  Each file has a structure similar to
-    the following:
-
-    .. code-block:: text
-
-       sample1-value1,sample1-value2,sample1-value3
-       sample2-value1,sample2-value2,sample2-value3
-       ...
-
-    Each file in the provided directory defines the subset name on the split.
-    So, the file ``train.csv`` will contain the data from the ``train`` subset,
-    and so on.
-
-    Objects of this class behave like a dictionary in which keys are subset
-    names in the split, and values represent samples data and meta-data.
-
-
-    Parameters
-    ----------
-
-    directory
-        Absolute path to a directory containing the database split layed down
-        as a set of CSV files, one per subset.
-    """
-
-    def __init__(
-        self, directory: pathlib.Path | str | importlib.abc.Traversable
-    ):
-        if isinstance(directory, str):
-            directory = pathlib.Path(directory)
-        assert (
-            directory.is_dir()
-        ), f"`{str(directory)}` is not a valid directory"
-        self.directory = directory
-        self.subsets = self._load_split_from_disk()
-
-    def _load_split_from_disk(
-        self,
-    ) -> dict[str, list[typing.Any]]:
-        """Loads all subsets in a split from its file system representation.
-
-        This method will load CSV information for the current split and return all
-        subsets of the given split after converting each entry through the
-        loader function.
-
-
-        Returns
-        -------
-
-        subsets : dict
-            A dictionary mapping subset names to lists of JSON objects
-        """
-
-        retval = {}
-        for subset in self.directory.iterdir():
-            if str(subset).endswith(".csv.bz2"):
-                logger.debug(f"Loading database split from {subset}...")
-                with __import__("bz2").open(subset) as f:
-                    reader = csv.reader(f)
-                    retval[subset.name[: -len(".csv.bz2")]] = [
-                        k for k in reader
-                    ]
-            elif str(subset).endswith(".csv"):
-                with subset.open() as f:
-                    reader = csv.reader(f)
-                    retval[subset.name[: -len(".csv")]] = [k for k in reader]
-            else:
-                logger.debug(
-                    f"Ignoring file {subset} in CSVDatabaseSplit readout"
-                )
-        return retval
-
-    def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
-        """Accesses subset ``key`` from this split."""
-        return self.subsets[key]
-
-    def __iter__(self):
-        """Iterates over the subsets."""
-        return iter(self.subsets)
-
-    def __len__(self) -> int:
-        """How many subsets we currently have."""
-        return len(self.subsets)
-
-
-def check_database_split_loading(
-    database_split: typing.Mapping[str, typing.Sequence[typing.Any]],
-    loader: typing.Callable[[typing.Any], torch.Tensor],
-    limit: int = 0,
-) -> int:
-    """For each subset in the split, check if all data can be correctly loaded
-    using the provided loader function.
-
-    This function will return the number of errors loading samples, and will
-    log more detailed information to the logging stream.
-
-
-    Parameters
-    ----------
-
-    database_split
-        A mapping that, contains the database split.  Each key represents the
-        name of a subset in the split.  Each value is a (potentially complex)
-        object that represents a single sample.
-
-    loader
-        A callable that transforms sample entries in the database split
-        into :py:class:`torch.Tensor` objects that can be used for training
-        or inference.
-
-    limit
-        Maximum number of samples to check (in each split/subset
-        combination) in this dataset.  If set to zero, then check
-        everything.
-
-
-    Returns
-    -------
-
-    errors
-        Number of errors found
-    """
-    logger.info(
-        "Checking if can load all samples in all subsets of this split..."
-    )
-    errors = 0
-    for subset in database_split.keys():
-        samples = subset if not limit else subset[:limit]
-        for pos, sample in enumerate(samples):
-            try:
-                data = loader(sample)
-                assert isinstance(data, torch.Tensor)
-            except Exception as e:
-                logger.info(
-                    f"Found error loading entry {pos} in subset `{subset}`: {e}"
-                )
-                errors += 1
-    return errors
-
-
 def get_positive_weights(dataset):
     """Compute the positive weights of each class of the dataset to balance the
     BCEWithLogitsLoss criterion.
@@ -350,45 +101,3 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
             )
         else:
             logger.warning("Weighted valid criterion not supported")
-
-
-def normalize_data(normalization, model, datamodule):
-    from torch.utils.data import DataLoader
-
-    datamodule.prepare_data()
-    datamodule.setup(stage="fit")
-
-    train_dataset = datamodule.train_dataset
-
-    # Create z-normalization model layer if needed
-    if normalization == "imagenet":
-        model.normalizer.set_mean_std(
-            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
-        )
-        logger.info("Z-normalization with ImageNet mean and std")
-    elif normalization == "current":
-        # Compute mean/std of current train subset
-        temp_dl = DataLoader(
-            dataset=train_dataset, batch_size=len(train_dataset)
-        )
-
-        data = next(iter(temp_dl))
-        mean = data[1].mean(dim=[0, 2, 3])
-        std = data[1].std(dim=[0, 2, 3])
-
-        model.normalizer.set_mean_std(mean, std)
-
-        # Format mean and std for logging
-        mean = str(
-            [
-                round(x, 3)
-                for x in ((mean * 10**3).round() / (10**3)).tolist()
-            ]
-        )
-        std = str(
-            [
-                round(x, 3)
-                for x in ((std * 10**3).round() / (10**3)).tolist()
-            ]
-        )
-        logger.info(f"Z-normalization with mean {mean} and std {std}")
diff --git a/src/ptbench/data/padchest/__init__.py b/src/ptbench/data/padchest/__init__.py
index af1dd3ec..f6f39a9e 100644
--- a/src/ptbench/data/padchest/__init__.py
+++ b/src/ptbench/data/padchest/__init__.py
@@ -264,7 +264,7 @@ json_dataset = JSONDataset(
 def _maker(protocol, resize_size=512, cc_size=512, RGB=True):
     import torchvision.transforms as transforms
 
-    from ..transforms import SingleAutoLevel16to8
+    from ..raw_data_loader import SingleAutoLevel16to8
 
     post_transforms = []
     if not RGB:
diff --git a/src/ptbench/data/loader.py b/src/ptbench/data/raw_data_loader.py
similarity index 51%
rename from src/ptbench/data/loader.py
rename to src/ptbench/data/raw_data_loader.py
index d6e86ed0..d852743f 100644
--- a/src/ptbench/data/loader.py
+++ b/src/ptbench/data/raw_data_loader.py
@@ -5,9 +5,42 @@
 
 """Data loading code."""
 
+import numpy
 import PIL.Image
 
 
+class SingleAutoLevel16to8:
+    """Converts a 16-bit image to 8-bit representation using "auto-level".
+
+    This transform assumes that the input image is grayscale.
+
+    To auto-level, we compute the minimum and the maximum of the image, and
+    map that range to the [0, 255] range of the destination image.
+    """
+
+    def __call__(self, img):
+        imin, imax = img.getextrema()
+        irange = imax - imin
+        return PIL.Image.fromarray(
+            numpy.round(
+                255.0 * (numpy.array(img).astype(float) - imin) / irange
+            ).astype("uint8"),
+        ).convert("L")
+
+
+class RemoveBlackBorders:
+    """Remove black borders of CXR."""
+
+    def __init__(self, threshold=0):
+        self.threshold = threshold
+
+    def __call__(self, img):
+        img = numpy.asarray(img)
+        mask = numpy.asarray(img) > self.threshold
+        return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
+
+
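+# Example (sketch; file name is hypothetical): the classes above are typically
+# composed with torchvision transforms when loading 16-bit CXR images, e.g.:
+#
+#   pre = torchvision.transforms.Compose(
+#       [RemoveBlackBorders(), SingleAutoLevel16to8()]
+#   )
+#   img = pre(PIL.Image.open("some-cxr.png"))
+
+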
 def load_pil(path):
     """Loads a sample data.
 
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py
index c9068112..8d943292 100644
--- a/src/ptbench/data/shenzhen/default.py
+++ b/src/ptbench/data/shenzhen/default.py
@@ -15,13 +15,15 @@ This configuration:
 import importlib.resources
 
 from ..datamodule import CachingDataModule
-from ..dataset import JSONDatabaseSplit
+from ..split import JSONDatabaseSplit
 from ..transforms import ElasticDeformation
-from .loader import raw_data_loader
+from .raw_data_loader import raw_data_loader
 
 datamodule = CachingDataModule(
     database_split=JSONDatabaseSplit(
-        importlib.resources.files(__name__).joinpath("default.json.bz2")
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "default.json.bz2"
+        )
     ),
     raw_data_loader=raw_data_loader,
     cache_samples=False,
diff --git a/src/ptbench/data/shenzhen/loader.py b/src/ptbench/data/shenzhen/raw_data_loader.py
similarity index 96%
rename from src/ptbench/data/shenzhen/loader.py
rename to src/ptbench/data/shenzhen/raw_data_loader.py
index 6578fc8f..0c2c39df 100644
--- a/src/ptbench/data/shenzhen/loader.py
+++ b/src/ptbench/data/shenzhen/raw_data_loader.py
@@ -27,8 +27,7 @@ import torch.nn
 import torchvision.transforms
 
 from ...utils.rc import load_rc
-from ..loader import load_pil_baw
-from ..transforms import RemoveBlackBorders
+from ..raw_data_loader import RemoveBlackBorders, load_pil_baw
 
 _datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
 """This variable contains the base directory where the database raw data is
diff --git a/src/ptbench/data/split.py b/src/ptbench/data/split.py
new file mode 100644
index 00000000..06e813ee
--- /dev/null
+++ b/src/ptbench/data/split.py
@@ -0,0 +1,258 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import collections.abc
+import csv
+import importlib.abc
+import json
+import logging
+import pathlib
+import typing
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+class JSONDatabaseSplit(
+    dict,
+    typing.Mapping[str, typing.Sequence[typing.Any]],
+):
+    """Defines a loader that understands a database split (train, test, etc) in
+    JSON format.
+
+    To create a new database split, you need to provide a JSON formatted
+    dictionary in a file, with contents similar to the following:
+
+    .. code-block:: json
+
+       {
+           "subset1": [
+               [
+                   "sample1-data1",
+                   "sample1-data2",
+                   "sample1-data3",
+               ],
+               [
+                   "sample2-data1",
+                   "sample2-data2",
+                   "sample2-data3",
+               ]
+           ],
+           "subset2": [
+               [
+                   "sample42-data1",
+                   "sample42-data2",
+                   "sample42-data3",
+               ],
+           ]
+       }
+
+    Your database split may contain any number of subsets (dictionary keys).
+    For simplicity, we recommend all sample entries are formatted similarly so
+    that raw-data-loading is simplified.  Use the function
+    :py:func:`check_database_split_loading` to test raw data loading and
+    fine-tune the dataset split, or its loading.
+
+    Objects of this class behave like a dictionary in which keys are subset
+    names in the split, and values represent sample data and meta-data.
+
+
+    Parameters
+    ----------
+
+    path
+        Absolute path to a JSON formatted file containing the database split to be
+        recognized by this object.
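+
+    Example (sketch; the path is hypothetical):
+
+    .. code-block:: python
+
+       split = JSONDatabaseSplit("/path/to/split.json")
+       train_samples = split["train"]  # sequence of sample entries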
+    """
+
+    def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable):
+        if isinstance(path, str):
+            path = pathlib.Path(path)
+        self.path = path
+        self.subsets = self._load_split_from_disk()
+
+    def _load_split_from_disk(self) -> dict[str, typing.Sequence[typing.Any]]:
+        """Loads all subsets in a split from its file system representation.
+
+        This method will load JSON information for the current split and
+        return all subsets of the given split, as they are stored in the
+        JSON file.
+
+
+        Returns
+        -------
+
+        subsets : dict
+            A dictionary mapping subset names to lists of JSON objects
+        """
+
+        if str(self.path).endswith(".bz2"):
+            logger.debug(f"Loading database split from {str(self.path)}...")
+            with __import__("bz2").open(self.path) as f:
+                return json.load(f)
+        else:
+            with self.path.open() as f:
+                return json.load(f)
+
+    def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
+        """Accesses subset ``key`` from this split."""
+        return self.subsets[key]
+
+    def __iter__(self):
+        """Iterates over the subsets."""
+        return iter(self.subsets)
+
+    def __len__(self) -> int:
+        """How many subsets we currently have."""
+        return len(self.subsets)
+
+
+class CSVDatabaseSplit(collections.abc.Mapping):
+    """Defines a loader that understands a database split (train, test, etc) in
+    CSV format.
+
+    To create a new database split, you need to provide one or more CSV
+    formatted files, each representing a subset of this split, containing the
+    sample data (one per row).  Example:
+
+    Inside the directory ``my-split/``, one can find the files ``train.csv``,
+    ``validation.csv``, and ``test.csv``.  Each file has a structure similar to
+    the following:
+
+    .. code-block:: text
+
+       sample1-value1,sample1-value2,sample1-value3
+       sample2-value1,sample2-value2,sample2-value3
+       ...
+
+    Each file in the provided directory defines a subset name in the split.
+    So, the file ``train.csv`` will contain the data from the ``train`` subset,
+    and so on.
+
+    Objects of this class behave like a dictionary in which keys are subset
+    names in the split, and values represent sample data and meta-data.
+
+
+    Parameters
+    ----------
+
+    directory
+        Absolute path to a directory containing the database split laid out
+        as a set of CSV files, one per subset.
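+
+    Example (sketch; the directory name is hypothetical):
+
+    .. code-block:: python
+
+       split = CSVDatabaseSplit("/path/to/my-split")
+       test_samples = split["test"]  # rows of ``test.csv``, as lists of str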
+    """
+
+    def __init__(
+        self, directory: pathlib.Path | str | importlib.abc.Traversable
+    ):
+        if isinstance(directory, str):
+            directory = pathlib.Path(directory)
+        assert (
+            directory.is_dir()
+        ), f"`{str(directory)}` is not a valid directory"
+        self.directory = directory
+        self.subsets = self._load_split_from_disk()
+
+    def _load_split_from_disk(
+        self,
+    ) -> dict[str, list[typing.Any]]:
+        """Loads all subsets in a split from its file system representation.
+
+        This method will load CSV information for the current split and return
+        all subsets of the given split, with each sample row represented as a
+        list of strings.
+
+
+        Returns
+        -------
+
+        subsets : dict
+            A dictionary mapping subset names to lists of CSV rows
+        """
+
+        retval = {}
+        for subset in self.directory.iterdir():
+            if str(subset).endswith(".csv.bz2"):
+                logger.debug(f"Loading database split from {subset}...")
+                with __import__("bz2").open(subset) as f:
+                    reader = csv.reader(f)
+                    retval[subset.name[: -len(".csv.bz2")]] = [
+                        k for k in reader
+                    ]
+            elif str(subset).endswith(".csv"):
+                with subset.open() as f:
+                    reader = csv.reader(f)
+                    retval[subset.name[: -len(".csv")]] = [k for k in reader]
+            else:
+                logger.debug(
+                    f"Ignoring file {subset} in CSVDatabaseSplit readout"
+                )
+        return retval
+
+    def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
+        """Accesses subset ``key`` from this split."""
+        return self.subsets[key]
+
+    def __iter__(self):
+        """Iterates over the subsets."""
+        return iter(self.subsets)
+
+    def __len__(self) -> int:
+        """How many subsets we currently have."""
+        return len(self.subsets)
+
+
+def check_database_split_loading(
+    database_split: typing.Mapping[str, typing.Sequence[typing.Any]],
+    loader: typing.Callable[[typing.Any], torch.Tensor],
+    limit: int = 0,
+) -> int:
+    """For each subset in the split, check if all data can be correctly loaded
+    using the provided loader function.
+
+    This function will return the number of errors loading samples, and will
+    log more detailed information to the logging stream.
+
+
+    Parameters
+    ----------
+
+    database_split
+        A mapping that contains the database split.  Each key represents the
+        name of a subset in the split.  Each value is a sequence of
+        (potentially complex) objects, each representing a single sample.
+
+    loader
+        A callable that transforms sample entries in the database split
+        into :py:class:`torch.Tensor` objects that can be used for training
+        or inference.
+
+    limit
+        Maximum number of samples to check (in each split/subset
+        combination) in this dataset.  If set to zero, then check
+        everything.
+
+
+    Returns
+    -------
+
+    errors
+        Number of errors found
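+
+    Example (sketch; assumes ``split`` is a database split mapping and
+    ``loader`` returns a :py:class:`torch.Tensor` per sample):
+
+    .. code-block:: python
+
+       errors = check_database_split_loading(split, loader, limit=10)
+       assert errors == 0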
+    """
+    logger.info(
+        "Checking if can load all samples in all subsets of this split..."
+    )
+    errors = 0
+    for subset in database_split.keys():
+        samples = database_split[subset]
+        if limit:
+            samples = samples[:limit]
+        for pos, sample in enumerate(samples):
+            try:
+                data = loader(sample)
+                assert isinstance(data, torch.Tensor)
+            except Exception as e:
+                logger.info(
+                    f"Found error loading entry {pos} in subset `{subset}`: {e}"
+                )
+                errors += 1
+    return errors
diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/transforms.py
index c85c3e9d..cf1946e9 100644
--- a/src/ptbench/data/transforms.py
+++ b/src/ptbench/data/transforms.py
@@ -22,44 +22,9 @@ from scipy.ndimage import gaussian_filter, map_coordinates
 from torchvision import transforms
 
 
-class SingleAutoLevel16to8:
-    """Converts a 16-bit image to 8-bit representation using "auto-level".
-
-    This transform assumes that the input image is gray-scaled.
-
-    To auto-level, we calculate the maximum and the minimum of the
-    image, and
-    consider such a range should be mapped to the [0,255] range of the
-    destination image.
-    """
-
-    def __call__(self, img):
-        imin, imax = img.getextrema()
-        irange = imax - imin
-        return PIL.Image.fromarray(
-            numpy.round(
-                255.0 * (numpy.array(img).astype(float) - imin) / irange
-            ).astype("uint8"),
-        ).convert("L")
-
-
-class RemoveBlackBorders:
-    """Remove black borders of CXR."""
-
-    def __init__(self, threshold=0):
-        self.threshold = threshold
-
-    def __call__(self, img):
-        img = numpy.asarray(img)
-        mask = numpy.asarray(img) > self.threshold
-        return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
-
-
 class ElasticDeformation:
     """Elastic deformation of 2D image slightly adapted from [SIMARD-2003]_.
 
-    TODO: needs to be converted into a torch.nn.Module to become scriptable!
-
     Source: https://gist.github.com/oeway/2e3b989e0343f0884388ed7ed82eb3b0
     """
 
@@ -70,7 +35,7 @@ class ElasticDeformation:
         spline_order=1,
         mode="nearest",
         random_state=numpy.random,
-        p=1,
+        p=1.0,
     ):
         self.alpha = alpha
         self.sigma = sigma
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index a7cf9d56..52bd27f0 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -48,6 +48,21 @@ class Densenet(pl.LightningModule):
 
         return x
 
+    def set_normalizer(self, dataloader):
+        """TODO: Write this function to set the Normalizer
+
+        This function is NOOP if ``pretrained = True`` (normalizer set to
+        imagenet weights, during contruction).
+        """
+        if self.pretrained:
+            from .normalizer import TorchVisionNormalizer
+
+            self.normalizer = TorchVisionNormalizer(..., ...)
+        else:
+            from .normalizer import get_znorm_normalizer
+
+            self.normalizer = get_znorm_normalizer(dataloader)
+
     def training_step(self, batch, batch_idx):
         images = batch[1]
         labels = batch[2]
diff --git a/src/ptbench/models/normalizer.py b/src/ptbench/models/normalizer.py
index aa2142a6..10320ab1 100644
--- a/src/ptbench/models/normalizer.py
+++ b/src/ptbench/models/normalizer.py
@@ -6,6 +6,7 @@
 
 import torch
 import torch.nn
+import torch.utils.data
 
 
 class TorchVisionNormalizer(torch.nn.Module):
@@ -20,19 +21,33 @@ class TorchVisionNormalizer(torch.nn.Module):
         Number of images channels fed to the model
     """
 
-    def __init__(self, nb_channels=3):
+    def __init__(self, subtract: torch.Tensor, divide: torch.Tensor):
         super().__init__()
-        mean = torch.zeros(nb_channels)[None, :, None, None]
-        std = torch.ones(nb_channels)[None, :, None, None]
-        self.register_buffer("mean", mean)
-        self.register_buffer("std", std)
+        assert len(subtract) == len(divide), (
+            "subtract and divide must have the same number of planes"
+        )
+        assert len(subtract) in (1, 3), "only 1 or 3 planes are supported"
+        subtract = torch.as_tensor(subtract)[None, :, None, None]
+        divide = torch.as_tensor(divide)[None, :, None, None]
+        self.register_buffer("subtract", subtract)
+        self.register_buffer("divide", divide)
         self.name = "torchvision-normalizer"
 
-    def set_mean_std(self, mean, std):
-        mean = torch.as_tensor(mean)[None, :, None, None]
-        std = torch.as_tensor(std)[None, :, None, None]
-        self.register_buffer("mean", mean)
-        self.register_buffer("std", std)
+    def forward(self, inputs: torch.Tensor):
+        """inputs shape [batches, planes, height, width]"""
+        return inputs.sub(self.subtract).div(self.divide)
 
-    def forward(self, inputs):
-        return inputs.sub(self.mean).div(self.std)
+
+def get_znorm_normalizer(
+    dataloader: torch.utils.data.DataLoader,
+) -> TorchVisionNormalizer:
+    """Returns a z-normalizer with mean/std estimated from ``dataloader``.
+
+    Pass an *unaugmented* training dataloader (e.g. from
+    ``CachingDataModule.unaugmented_train_dataloader()``).  Only applicable
+    if the model does not re-use weights from imagenet pre-training.
+    """
+    # 1. extract per-plane mean/std from the dataloader (assumes batches are
+    #    ``(tensor, metadata)`` pairs, with tensors shaped ``[n, c, h, w]``)
+    summed = 0.0
+    summed_sq = 0.0
+    num_pixels = 0
+    for batch in dataloader:
+        data = batch[0]
+        summed = summed + data.sum(dim=[0, 2, 3])
+        summed_sq = summed_sq + (data**2).sum(dim=[0, 2, 3])
+        num_pixels += data.size(0) * data.size(2) * data.size(3)
+    mean = summed / num_pixels
+    std = (summed_sq / num_pixels - mean**2).sqrt()
+
+    # 2. return a normalizer configured with these statistics
+    return TorchVisionNormalizer(mean, std)
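+
+
+# Example (sketch): a model's normalizer would typically be built from the
+# unaugmented training data exposed by the datamodule, e.g.:
+#
+#   loader = datamodule.unaugmented_train_dataloader()
+#   model.normalizer = get_znorm_normalizer(loader)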
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 9da9702f..d14ea5d4 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -6,6 +6,7 @@ import lightning.pytorch as pl
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.utils.data
 
 from .normalizer import TorchVisionNormalizer
 
@@ -127,6 +128,12 @@ class PASA(pl.LightningModule):
 
         return x
 
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """TODO: Write this function documentation"""
+        from .normalizer import get_znorm_normalizer
+
+        self.normalizer = get_znorm_normalizer(dataloader)
+
     def training_step(self, batch, batch_idx):
         images = batch["data"]
         labels = batch["label"]
-- 
GitLab