From b61397ae1d7ca2c7f202e66f6ec256a3bfbdbfbf Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 14 Jun 2023 11:00:46 +0200
Subject: [PATCH] Set up datamodule only once per stage; get only the needed subsets

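Lightning may call a datamodule's setup() hook more than once (typically
once per trainer stage). Guard the Shenzhen datamodules with a
has_setup_fit flag so the "fit" subsets are built only once, and pass the
stage down to return_subsets() so it creates only the subsets that stage
needs. For now only the "fit" stage is handled; any other stage raises
ValueError.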
---
 .../configs/datasets/shenzhen/default.py      | 17 +--
 src/ptbench/configs/datasets/shenzhen/rgb.py  | 17 +--
 src/ptbench/data/__init__.py                  | 88 ++++++++++---------
 3 files changed, 70 insertions(+), 52 deletions(-)
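
A minimal, self-contained sketch of the guard pattern this patch
introduces (it assumes nothing from ptbench; ToyDataModule and its
placeholder lists are invented stand-ins for DefaultModule and its real
SampleListDataset subsets, and only the control flow mirrors the patch):

    class ToyDataModule:
        """Stand-in for DefaultModule: same has_setup_fit guard."""

        def __init__(self):
            self.has_setup_fit = False
            self.train_dataset = None
            self.validation_dataset = None

        def setup(self, stage: str):
            # a trainer may call setup() repeatedly; the flag ensures the
            # (potentially expensive) subset creation happens only once
            if not self.has_setup_fit and stage == "fit":
                self.train_dataset = ["train-sample"]       # placeholder
                self.validation_dataset = ["valid-sample"]  # placeholder
                self.has_setup_fit = True

    dm = ToyDataModule()
    dm.setup("fit")
    dm.setup("fit")  # no-op: the subsets already exist
    assert dm.train_dataset == ["train-sample"]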

diff --git a/src/ptbench/configs/datasets/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py
index 11c285bc..6dc59200 100644
--- a/src/ptbench/configs/datasets/shenzhen/default.py
+++ b/src/ptbench/configs/datasets/shenzhen/default.py
@@ -37,6 +37,8 @@ class DefaultModule(BaseDataModule):
         )
 
         self.cache_samples = cache_samples
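+        # flag used by setup() to build the "fit" subsets only once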
+        self.has_setup_fit = False
 
     def setup(self, stage: str):
         if self.cache_samples:
@@ -56,12 +58,15 @@ class DefaultModule(BaseDataModule):
             loader=samples_loader,
         )
 
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.json_dataset, "default")
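+        # setup() may be called once per trainer stage; only build the
+        # train/validation subsets the first time the "fit" stage is set up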
+        if not self.has_setup_fit and stage == "fit":
+            (
+                self.train_dataset,
+                self.validation_dataset,
+                self.extra_validation_datasets,
+            ) = return_subsets(self.json_dataset, "default", stage)
+            self.has_setup_fit = True
 
 
 datamodule = DefaultModule
diff --git a/src/ptbench/configs/datasets/shenzhen/rgb.py b/src/ptbench/configs/datasets/shenzhen/rgb.py
index 0ceb952e..2506e79d 100644
--- a/src/ptbench/configs/datasets/shenzhen/rgb.py
+++ b/src/ptbench/configs/datasets/shenzhen/rgb.py
@@ -37,6 +37,8 @@ class DefaultModule(BaseDataModule):
         )
 
         self.cache_samples = cache_samples
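+        # flag used by setup() to build the "fit" subsets only once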
+        self.has_setup_fit = False
 
         self.post_transforms = [
             transforms.ToPILImage(),
@@ -63,12 +65,15 @@ class DefaultModule(BaseDataModule):
             post_transforms=self.post_transforms,
         )
 
-        (
-            self.train_dataset,
-            self.validation_dataset,
-            self.extra_validation_datasets,
-            self.predict_dataset,
-        ) = return_subsets(self.json_dataset, "default")
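+        # setup() may be called once per trainer stage; only build the
+        # train/validation subsets the first time the "fit" stage is set up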
+        if not self.has_setup_fit and stage == "fit":
+            (
+                self.train_dataset,
+                self.validation_dataset,
+                self.extra_validation_datasets,
+            ) = return_subsets(self.json_dataset, "default", stage)
+            self.has_setup_fit = True
 
 
 datamodule = DefaultModule
diff --git a/src/ptbench/data/__init__.py b/src/ptbench/data/__init__.py
index 259bab57..682d5d1a 100644
--- a/src/ptbench/data/__init__.py
+++ b/src/ptbench/data/__init__.py
@@ -303,49 +303,57 @@ def get_positive_weights(dataset):
     return positive_weights
 
 
-def return_subsets(dataset, protocol):
-    train_dataset = None
-    validation_dataset = None
-    extra_validation_datasets = None
-    predict_dataset = None
+def return_subsets(dataset, protocol, stage):
+    train_set = None
+    valid_set = None
+    extra_valid_sets = None
 
     subsets = dataset.subsets(protocol)
-    if "train" in subsets.keys():
-        train_dataset = SampleListDataset(subsets["train"], [])
 
-    if "validation" in subsets.keys():
-        validation_dataset = SampleListDataset(subsets["validation"], [])
-    else:
-        logger.warning(
-            "No validation dataset found, using training set instead."
-        )
-        validation_dataset = train_dataset
-
-    if "__extra_valid__" in subsets.keys():
-        if not isinstance(subsets["__extra_valid__"], list):
-            raise RuntimeError(
-                f"If present, dataset['__extra_valid__'] must be a list, "
-                f"but you passed a {type(subsets['__extra_valid__'])}, "
-                f"which is invalid."
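+    # each helper builds its subset lazily, so a stage only materializes
+    # the subsets it actually needs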
+    def get_train_subset():
+        nonlocal train_set
+        if "train" in subsets.keys():
+            train_set = SampleListDataset(subsets["train"], [])
+
+    def get_valid_subset():
+        nonlocal valid_set
+        if "validation" in subsets.keys():
+            valid_set = SampleListDataset(subsets["validation"], [])
+        else:
+            logger.warning(
+                "No validation dataset found, using training set instead."
             )
-        logger.info(
-            f"Found {len(subsets['__extra_valid__'])} extra validation "
-            f"set(s) to be tracked during training"
-        )
-        logger.info(
-            "Extra validation sets are NOT used for model checkpointing!"
-        )
-        extra_validation_datasets = SampleListDataset(
-            subsets["__extra_valid__"], []
-        )
-    else:
-        extra_validation_datasets = None
+            if train_set is None:
+                get_train_subset()
+
+            valid_set = train_set
+
+    def get_extra_valid_subset():
+        nonlocal extra_valid_sets
+        if "__extra_valid__" in subsets.keys():
+            if not isinstance(subsets["__extra_valid__"], list):
+                raise RuntimeError(
+                    f"If present, dataset['__extra_valid__'] must be a list, "
+                    f"but you passed a {type(subsets['__extra_valid__'])}, "
+                    f"which is invalid."
+                )
+            logger.info(
+                f"Found {len(subsets['__extra_valid__'])} extra validation "
+                f"set(s) to be tracked during training"
+            )
+            logger.info(
+                "Extra validation sets are NOT used for model checkpointing!"
+            )
+            extra_valid_sets = SampleListDataset(subsets["__extra_valid__"], [])
 
-    predict_dataset = subsets
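+    # only the "fit" stage is wired up so far; any other stage is rejected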
+    if stage == "fit":
+        get_train_subset()
+        get_valid_subset()
+        get_extra_valid_subset()
 
-    return (
-        train_dataset,
-        validation_dataset,
-        extra_validation_datasets,
-        predict_dataset,
-    )
+        return train_set, valid_set, extra_valid_sets
+    else:
+        raise ValueError(f"Stage '{stage}' is not supported.")
-- 
GitLab