From c464ae3eed032bd5e90bf9673720958bd180cdaf Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Mon, 3 Jul 2023 13:10:29 +0200
Subject: [PATCH] Updates to the data module and models

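Data augmentations are no longer handled by ``CachingDataModule``: only
``model_transforms`` are applied on top of the raw-data-loader output.
Datasets are built through a single ``_setup_dataset()`` helper, and a
new ``teardown()`` hook releases (possibly cached) datasets.  The model
normalizer is now set from the regular ``train_dataloader()`` (the
separate unaugmented loader is gone), loss re-weighting moved into
``model.set_bce_loss_weights()``, and the ``--criterion``,
``--criterion-valid`` and ``--normalization`` command-line options were
removed from the train script.

For illustration only, here is a minimal sketch of the updated
datamodule usage (``my_split`` and ``my_raw_data_loader`` are
placeholders and the constructor arguments shown are abbreviated):

    import torchvision.transforms

    from ptbench.data.datamodule import CachingDataModule

    datamodule = CachingDataModule(
        my_split,                      # maps dataset names to sample lists
        raw_data_loader=my_raw_data_loader,
        model_transforms=[torchvision.transforms.Resize((512, 512))],
        cache_samples=False,           # set to True to preload all samples
    )
    datamodule.setup(stage="fit")      # builds train/validation datasets
    train_loader = datamodule.train_dataloader()
    # ... fit a model with the loader ...
    datamodule.teardown(stage="fit")   # drops dataset references and cache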
---
 src/ptbench/data/datamodule.py | 133 ++++++++++++++++++---------------
 src/ptbench/engine/trainer.py  |  47 +++++++-----
 src/ptbench/models/densenet.py |  15 ++++
 src/ptbench/models/pasa.py     |  10 +++
 src/ptbench/scripts/train.py   |  61 ++++++---------
 5 files changed, 146 insertions(+), 120 deletions(-)

diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 633ebd62..98e003c5 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import collections
+import logging
 import multiprocessing
 import sys
 import typing
@@ -12,6 +13,8 @@ import torch
 import torch.utils.data
 import torchvision.transforms
 
+logger = logging.getLogger(__name__)
+
 
 def _setup_dataloader_multiproc_parameters(
     parallel: int,
@@ -80,7 +83,8 @@ class _DelayedLoadingDataset(torch.utils.data.Dataset):
         interval ``[0, 1]``, and a dictionary with further metadata attributes.
 
     transforms
-        A set of transforms that should be applied on-the-fly for this dataset.
+        A set of transforms that should be applied on-the-fly for this dataset,
+        to fit the output of the raw-data-loader to the model of interest.
     """
 
     def __init__(
@@ -113,7 +117,8 @@ class _CachedDataset(torch.utils.data.Dataset):
     """Basically, a list of preloaded samples.
 
     This dataset will load all samples from the split during construction
-    instead of delaying that to the indexing.
+    instead of delaying that to the indexing.  Beyond raw-data-loading, the
+    ``transforms`` given upon construction are applied to returned samples.
 
 
     Parameters
@@ -130,7 +135,9 @@ class _CachedDataset(torch.utils.data.Dataset):
         interval ``[0, 1]``, and a dictionary with further metadata attributes.
 
     transforms
-        A set of transforms that should be applied on-the-fly for this dataset.
+        A set of transforms that should be applied to the cached samples for
+        this dataset, to fit the output of the raw-data-loader to the model of
+        interest.
     """
 
     def __init__(
@@ -143,8 +150,8 @@ class _CachedDataset(torch.utils.data.Dataset):
             typing.Callable[[torch.Tensor], torch.Tensor]
         ] = [],
     ):
-        self.data = [raw_data_loader(k) for k in split]
         self.transform = torchvision.transforms.Compose(transforms)
+        self.data = [raw_data_loader(k) for k in split]
 
     def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
         tensor, metadata = self.data[key]
@@ -286,21 +293,13 @@ class CachingDataModule(lightning.LightningDataModule):
         particularly important in highly unbalanced datasets.  The function
         :py:func:`get_sample_weights` may help you in this aspect.
 
-    data_augmentations
-        A list of torchvision transforms (torch modules) that will be applied
-        on training set samples to create data augmentations during the
-        training of a model. Augmentation transform pipelines are applied
-        *after* the raw data is loaded, and before ``model_transforms``.
-        Augmentation transforms assume they receive a torch tensor representing
-        an image as input (see :py:class:`torchvision.transforms.ToTensor` for
-        details), in the range ``[0, 1]``.
-
     model_transforms
-        A list of torchvision transforms (torch modules) that will be applied
-        after data augmentation transforms, and just before data is fed into
-        the model for all data loaders produced by this data module.  This part
-        of the pipeline receives data as output by the raw-data-loader, or from
-        data augmentations, if any is specified.
+        A list of transforms (torch modules) that will be applied after
+        raw-data-loading, and just before data is fed into the model or
+        eventual data-augmentation transformations, for all data loaders
+        produced by this data module.  This part of the pipeline receives
+        data as output by the raw-data-loader and adapts it to the model of
+        interest (e.g. through resize operations).
 
     batch_size
         Number of samples in every **training** batch (this parameter affects
@@ -348,9 +347,6 @@ class CachingDataModule(lightning.LightningDataModule):
         ],
         cache_samples: bool = False,
         train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
-        data_augmentations: list[
-            typing.Callable[[torch.Tensor], torch.Tensor]
-        ] = [],
         model_transforms: list[
             typing.Callable[[torch.Tensor], torch.Tensor]
         ] = [],
@@ -367,7 +363,6 @@ class CachingDataModule(lightning.LightningDataModule):
         self.raw_data_loader = raw_data_loader
         self.cache_samples = cache_samples
         self.train_sampler = train_sampler
-        self.data_augmentations = data_augmentations
         self.model_transforms = model_transforms
 
         self.drop_incomplete_batch = drop_incomplete_batch
@@ -437,10 +432,36 @@ class CachingDataModule(lightning.LightningDataModule):
         self._batch_chunk_count = batch_chunk_count
         self._chunk_size = self._batch_size // self._batch_chunk_count
 
+    def _setup_dataset(self, name: str) -> None:
+        """Sets-up a single dataset from the input data split.
+
+        Parameters
+        ----------
+
+        name
+            Name of the dataset to set up.
+        """
+
+        if name in self._datasets:
+            logger.info(f"Dataset {name} is already setup.  Not reloading it.")
+            return
+        if self.cache_samples:
+            self._datasets[name] = _CachedDataset(
+                self.database_split[name],
+                self.raw_data_loader,
+                self.model_transforms,
+            )
+        else:
+            self._datasets[name] = _DelayedLoadingDataset(
+                self.database_split[name],
+                self.raw_data_loader,
+                self.model_transforms,
+            )
+
     def setup(self, stage: str) -> None:
         """Sets up datasets for different tasks on the pipeline.
 
-        This method should setup (load, pre-process, etc) all subsets required
+        This method should set up (load, pre-process, etc) all datasets required
         for a particular ``stage`` (fit, validate, test, predict), and keep
         them ready to be used on one of the `_dataloader()` functions that are
         pertinent for such stage.
@@ -463,62 +484,50 @@ class CachingDataModule(lightning.LightningDataModule):
             * ``predict``: uses only the test dataset
         """
 
-        def _setup(name, transforms):
-            if self.cache_samples:
-                self._datasets[name] = _CachedDataset(
-                    self.database_split[name], self.raw_data_loader, transforms
-                )
-            else:
-                self._datasets[name] = _DelayedLoadingDataset(
-                    self.database_split[name], self.raw_data_loader, transforms
-                )
-
         if stage == "fit":
-            _setup("train", self.data_augmentations + self.model_transforms)
-            _setup("validation", self.model_transforms)
+            self._setup_dataset("train")
+            self._setup_dataset("validation")
             for k in self.database_split:
                 if k.startswith("monitor-"):
-                    _setup(k, self.model_transforms)
+                    self._setup_dataset(k)
 
         elif stage == "validate":
-            _setup("validation", self.model_transforms)
+            self._setup_dataset("validation")
             for k in self.database_split:
                 if k.startswith("monitor-"):
-                    _setup(k, self.model_transforms)
+                    self._setup_dataset(k)
 
         elif stage == "test":
-            _setup("test", self.model_transforms)
+            self._setup_dataset("test")
 
         elif stage == "predict":
-            _setup("test", self.model_transforms)
+            self._setup_dataset("test")
+
+    def teardown(self, stage: str) -> None:
+        """Unset-up datasets for different tasks on the pipeline.
+
+        This method unloads (removes from memory, etc) all datasets required
+        for a particular ``stage`` (fit, validate, test, predict).
 
-    def unaugmented_train_dataloader(self) -> torch.utils.data.DataLoader:
-        """Returns a version of the train dataloader without augmentations.
+        If ``cache_samples`` was set, samples were loaded into memory during
+        setup, and this call may effectively release all associated memory.
 
-        Use this method to obtain a version of the train dataloader without
-        augmentations, to compute input normalisation factors (e.g. mean and
-        standard deviation or min-max parameterisations).
 
+        Parameters
+        ----------
 
-        Returns
-        -------
+        stage
+            Name of the stage to which the teardown is applicable.  Can be one of
+            ``fit``, ``validate``, ``test`` or ``predict``.  Each stage
+            typically uses the following data loaders:
 
-        dataloader
-            The unaugmented train dataloader
+            * ``fit``: uses both train and validation datasets
+            * ``validate``: uses only the validation dataset
+            * ``test``: uses only the test dataset
+            * ``predict``: uses only the test dataset
         """
-        dataset = _DelayedLoadingDataset(
-            self.database_split["train"],
-            self.raw_data_loader,
-            self.model_transforms,
-        )
-        return torch.utils.data.DataLoader(
-            dataset,
-            shuffle=False,
-            batch_size=self._chunk_size,
-            drop_last=self.drop_incomplete_batch,
-            pin_memory=self.pin_memory,
-            **self._dataloader_multiproc,
-        )
+
+        self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}
 
     def train_dataloader(self) -> torch.utils.data.DataLoader:
         """Returns the train data loader."""
diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index 2c8bdc55..6b156f86 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -7,10 +7,10 @@ import logging
 import os
 import shutil
 
-from lightning.pytorch import Trainer
-from lightning.pytorch.callbacks import ModelCheckpoint
-from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
-from lightning.pytorch.utilities.model_summary import ModelSummary
+import lightning.pytorch.callbacks
+import lightning.pytorch.loggers
+import torch.nn
+from lightning.pytorch.utilities.model_summary import ModelSummary
 
 from ..utils.accelerator import AcceleratorProcessor
 from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
@@ -19,7 +19,7 @@ from .callbacks import LoggingCallback
 logger = logging.getLogger(__name__)
 
 
-def check_gpu(device):
+def check_gpu(device: str) -> None:
     """Check the device type and the availability of GPU.
 
     Parameters
@@ -35,32 +35,37 @@ def check_gpu(device):
         ), f"Device set to '{device}', but nvidia-smi is not installed"
 
 
-def save_model_summary(output_folder, model):
+def save_model_summary(
+    output_folder: str, model: torch.nn.Module
+) -> tuple[ModelSummary, int]:
     """Save a little summary of the model in a txt file.
 
     Parameters
     ----------
 
-    output_folder : str
+    output_folder
         output path
 
-    model : :py:class:`torch.nn.Module`
+    model
         Network (e.g. driu, hed, unet)
 
     Returns
     -------
-    r : str
+    r
         The model summary in a text format.
 
-    n : int
+    n
         The number of parameters of the model.
     """
     summary_path = os.path.join(output_folder, "model_summary.txt")
     logger.info(f"Saving model summary at {summary_path}...")
     with open(summary_path, "w") as f:
-        summary = ModelSummary(model, max_depth=-1)
+        summary = ModelSummary(model, max_depth=-1)
         f.write(str(summary))
-    return summary, ModelSummary(model).total_parameters
+    return (
+        summary,
+        ModelSummary(model).total_parameters,
+    )
 
 
 def static_information_to_csv(static_logfile_name, device, n):
@@ -70,7 +75,8 @@ def static_information_to_csv(static_logfile_name, device, n):
     ----------
 
     static_logfile_name : str
-        The static file name which is a join between the output folder and "constant.csv"
+        The static file name which is a join between the output folder and
+        "constant.csv"
     """
     if os.path.exists(static_logfile_name):
         backup = static_logfile_name + "~"
@@ -188,7 +194,8 @@ def run(
         not save intermediary checkpoints.
 
     accelerator : str
-        A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)
+        A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The
+        device can also be specified (gpu:0)
 
     arguments : dict
         Start and end epochs:
@@ -215,10 +222,12 @@ def run(
     os.makedirs(output_folder, exist_ok=True)
 
     # Save model summary
-    r, n = save_model_summary(output_folder, model)
+    _, n = save_model_summary(output_folder, model)
 
-    csv_logger = CSVLogger(output_folder, "logs_csv")
-    tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard")
+    csv_logger = lightning.pytorch.loggers.CSVLogger(output_folder, "logs_csv")
+    tensorboard_logger = lightning.pytorch.loggers.TensorBoardLogger(
+        output_folder, "logs_tensorboard"
+    )
 
     resource_monitor = ResourceMonitor(
         interval=monitoring_interval,
@@ -227,7 +236,7 @@ def run(
         logging_level=logging.ERROR,
     )
 
-    checkpoint_callback = ModelCheckpoint(
+    checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint(
         output_folder,
         "model_lowest_valid_loss",
         save_last=True,
@@ -251,7 +260,7 @@ def run(
         devices = accelerator_processor.device
 
     with resource_monitor:
-        trainer = Trainer(
+        trainer = lightning.pytorch.Trainer(
             accelerator=accelerator_processor.accelerator,
             devices=devices,
             max_epochs=max_epoch,
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index 1aa1b268..e61d7dec 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -2,11 +2,15 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import logging
+
 import lightning.pytorch as pl
 import torch
 import torch.nn as nn
 import torchvision.models as models
 
+logger = logging.getLogger(__name__)
+
 
 class Densenet(pl.LightningModule):
     """Densenet module.
@@ -25,6 +29,8 @@ class Densenet(pl.LightningModule):
     ):
         super().__init__()
 
+        # Saves all hyper-parameters declared on __init__ into `self.hparams`.
+        # You can access those by name, e.g. `self.hparams.optimizer`.
         self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
 
         self.name = "Densenet"
@@ -55,10 +61,19 @@ class Densenet(pl.LightningModule):
         if self.pretrained:
             from .normalizer import make_imagenet_normalizer
 
+            logger.warning(
+                "ImageNet pre-trained densenet model - NOT"
+                "computing z-norm factors from training data. "
+                "Using preset factors from torchvision."
+            )
             self.normalizer = make_imagenet_normalizer()
         else:
             from .normalizer import make_z_normalizer
 
+            logger.info(
+                "Uninitialised densenet model - "
+                "computing z-norm factors from training data."
+            )
             self.normalizer = make_z_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 61105dbe..fbc73f81 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -2,12 +2,16 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import logging
+
 import lightning.pytorch as pl
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
 
+logger = logging.getLogger(__name__)
+
 
 class PASA(pl.LightningModule):
     """PASA module.
@@ -24,6 +28,8 @@ class PASA(pl.LightningModule):
     ):
         super().__init__()
 
+        # Saves all hyper-parameters declared on __init__ into `self.hparams`.
+        # You can access those by name, e.g. `self.hparams.criterion`.
         self.save_hyperparameters()
 
         self.name = "pasa"
@@ -137,6 +143,10 @@ class PASA(pl.LightningModule):
         """
         from .normalizer import make_z_normalizer
 
+        logger.info(
+            "Uninitialised densenet model - "
+            "computing z-norm factors from training data."
+        )
         self.normalizer = make_z_normalizer(dataloader)
 
     def training_step(self, batch, _):
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index eeb5b869..01f294d7 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -6,9 +6,6 @@ import click
 
 from clapper.click import ConfigCommand, ResourceOption, verbosity_option
 from clapper.logging import setup
-from lightning.pytorch import seed_everything
-
-from ..utils.checkpointer import get_checkpoint
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -65,21 +62,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     required=True,
     cls=ResourceOption,
 )
-@click.option(
-    "--criterion",
-    help="A loss function to compute the CNN error for every sample "
-    "respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
-    required=True,
-    cls=ResourceOption,
-)
-@click.option(
-    "--criterion-valid",
-    help="A specific loss function for the validation set to compute the CNN"
-    "error for every sample respecting the PyTorch API for loss functions"
-    "(see torch.nn.modules.loss)",
-    required=False,
-    cls=ResourceOption,
-)
 @click.option(
     "--batch-size",
     "-b",
@@ -159,7 +141,8 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @click.option(
     "--accelerator",
     "-a",
-    help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)',
+    help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). '
+    "The device can also be specified (gpu:0)",
     show_default=True,
     required=True,
     default="cpu",
@@ -167,7 +150,8 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 )
 @click.option(
     "--cache-samples",
-    help="If set to True, loads the sample into memory, otherwise loads them at runtime.",
+    help="If set to True, loads the sample into memory, "
+    "otherwise loads them at runtime.",
     required=True,
     show_default=True,
     default=False,
@@ -196,16 +180,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     default=-1,
     cls=ResourceOption,
 )
-@click.option(
-    "--normalization",
-    "-n",
-    help="Z-Normalization of input images: 'imagenet' for ImageNet parameters,"
-    " 'current' for parameters of the current trainset, "
-    "'none' for no normalization.",
-    required=False,
-    default="none",
-    cls=ResourceOption,
-)
 @click.option(
     "--monitoring-interval",
     "-I",
@@ -224,7 +198,8 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 )
 @click.option(
     "--resume-from",
-    help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a model checkpoint.",
+    help="Which checkpoint to resume training from. If set, can be one of "
+    "`best`, `last`, or a path to a model checkpoint.",
     type=str,
     required=False,
     default=None,
@@ -238,20 +213,16 @@ def train(
     batch_size,
     batch_chunk_count,
     drop_incomplete_batch,
-    criterion,
-    criterion_valid,
     datamodule,
     checkpoint_period,
     accelerator,
     cache_samples,
     seed,
     parallel,
-    normalization,
     monitoring_interval,
     resume_from,
-    **_,
 ):
-    """Trains an CNN to perform tuberculosis detection.
+    """Trains an CNN to perform image classification.
 
     Training is performed for a configurable number of epochs, and
     generates at least a final_model.pth.  It may also generate a number
@@ -263,21 +234,33 @@ def train(
     import torch.cuda
     import torch.nn
 
-    from ..data.dataset import reweight_BCEWithLogitsLoss
+    from lightning.pytorch import seed_everything
+
     from ..engine.trainer import run
+    from ..utils.checkpointer import get_checkpoint
 
     seed_everything(seed)
 
     checkpoint_file = get_checkpoint(output_folder, resume_from)
 
+    # Reset the datamodule with user-configurable options.
     datamodule.set_chunk_size(batch_size, batch_chunk_count)
+    datamodule.drop_incomplete_batch = drop_incomplete_batch
+    datamodule.cache_samples = cache_samples
     datamodule.parallel = parallel
 
     datamodule.prepare_data()
     datamodule.setup(stage="fit")
 
-    reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid)
-    model.set_normalizer(datamodule.unaugmented_train_dataloader())
+    # Sets the model normalizer with the unaugmented train subset.
+    # This call may be a no-op, if the model is pre-trained and expects
+    # different weights for the normalisation layer.
+    model.set_normalizer(datamodule.train_dataloader())
+
+    # Rebalances the loss criterion based on the relative proportion of class
+    # examples available in the training set.  Also affects the validation loss
+    # if a validation set is available on the data module.
+    model.set_bce_loss_weights(datamodule)
 
     arguments = {}
     arguments["max_epoch"] = epochs
-- 
GitLab