From 8afb13e5911df0846dc3e4f2a1ceaf5d0251ad4d Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Thu, 20 Jul 2023 23:33:38 +0200
Subject: [PATCH] [data.datamodule] Double-check model_transforms are set
 before datasets are instantiated; Remove model_transforms from constructor
 (non-sensical); Improve documentation on model-transforms

---
 src/ptbench/data/datamodule.py | 33 +++++++++++++++++++++++----------
 1 file changed, 23 insertions(+), 10 deletions(-)

diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 5c940918..c8c33dec 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -338,14 +338,6 @@ class CachingDataModule(lightning.LightningDataModule):
         validation to balance sample picking probability, making sample
         across classes **and** datasets equitable.
 
-    model_transforms
-        A list of transforms (torch modules) that will be applied after
-        raw-data-loading, and just before data is fed into the model or
-        eventual data-augmentation transformations for all data loaders
-        produced by this data module.  This part of the pipeline receives data
-        as output by the raw-data-loader, or model-related transforms (e.g.
-        resize adaptions), if any is specified.
-
     batch_size
         Number of samples in every **training** batch (this parameter affects
         memory requirements for the network).  If the number of samples in the
@@ -382,6 +374,21 @@ class CachingDataModule(lightning.LightningDataModule):
         multiprocessing data loading.  Set to 0 to enable as many data loading
         instances as processing cores as available in the system.  Set to >= 1
         to enable that many multiprocessing instances for data loading.
+
+
+    Attributes
+    ----------
+
+    model_transforms
+        A list of transforms (torch modules) that will be applied after
+        raw-data-loading, and just before data is fed into the model or
+        eventual data-augmentation transformations for all data loaders
+        produced by this data module.  This part of the pipeline receives data
+        as output by the raw-data-loader, or model-related transforms (e.g.
+        resize adaptions), if any is specified.  If data is cached, it is
+        cached **after** model-transforms are applied, as that is a potential
+        memory saver (e.g., if it contains a resizing operation to smaller
+        images).
     """
 
     DatasetDictionary = dict[str, Dataset]
@@ -392,7 +399,6 @@ class CachingDataModule(lightning.LightningDataModule):
         raw_data_loader: RawDataLoader,
         cache_samples: bool = False,
         balance_sampler_by_class: bool = False,
-        model_transforms: list[Transform] = [],
         batch_size: int = 1,
         batch_chunk_count: int = 1,
         drop_incomplete_batch: bool = False,
@@ -407,7 +413,7 @@ class CachingDataModule(lightning.LightningDataModule):
         self.cache_samples = cache_samples
         self._train_sampler = None
         self.balance_sampler_by_class = balance_sampler_by_class
-        self.model_transforms = model_transforms
+        self.model_transforms: list[Transform] | None = None
 
         self.drop_incomplete_batch = drop_incomplete_batch
         self.parallel = parallel  # immutable, otherwise would need to call
@@ -551,6 +557,13 @@ class CachingDataModule(lightning.LightningDataModule):
             Name of the dataset to setup.
         """
 
+        if self.model_transforms is None:
+            raise RuntimeError(
+                "Parameter `model_transforms` has not yet been "
+                "set.  If you do not have model transforms, then "
+                "set it to an empty list."
+            )
+
         if name in self._datasets:
             logger.info(
                 f"Dataset `{name}` is already setup. "
-- 
GitLab