From b0138f2fc36fecff0c995c752216780e1be3452a Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Mon, 5 Jun 2023 16:01:38 +0200
Subject: [PATCH] Add helper to instantiate a DataModule and retrieve its
 dataset (DRY)

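Every mc_ch configuration repeated the same boilerplate in setup():
instantiate the Montgomery and Shenzhen datamodules, call
prepare_data() and setup() on each, then read their .dataset
attribute. This patch factors that sequence into a single helper,
get_dataset_from_module(), in base_datamodule.py, and rewrites all
mc_ch configurations to use it.

A minimal sketch of the resulting call-site pattern (taken verbatim
from the hunks below; the keyword arguments mirror attributes every
BaseDataModule already carries):

    module_args = {
        "train_batch_size": self.train_batch_size,
        "predict_batch_size": self.predict_batch_size,
        "drop_incomplete_batch": self.drop_incomplete_batch,
        "multiproc_kwargs": self.multiproc_kwargs,
    }

    # One call per child datamodule replaces the previous
    # instantiate/prepare_data/setup/.dataset sequence.
    mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
    ch = get_dataset_from_module(ch_datamodule, stage, **module_args)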
---
 src/ptbench/data/base_datamodule.py  | 13 ++++++++++++
 src/ptbench/data/mc_ch/default.py    | 29 +++++++++------------------
 src/ptbench/data/mc_ch/fold_0.py     | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_0_rgb.py | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_1.py     | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_1_rgb.py | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_2.py     | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_2_rgb.py | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_3.py     | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_3_rgb.py | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_4.py     | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_4_rgb.py | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_5.py     | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_5_rgb.py | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_6.py     | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_6_rgb.py | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_7.py     | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_7_rgb.py | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_8.py     | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_8_rgb.py | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_9.py     | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/fold_9_rgb.py | 30 +++++++++-------------------
 src/ptbench/data/mc_ch/rgb.py        | 30 +++++++++-------------------
 23 files changed, 211 insertions(+), 461 deletions(-)

diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py
index fb0970f0..5e656d42 100644
--- a/src/ptbench/data/base_datamodule.py
+++ b/src/ptbench/data/base_datamodule.py
@@ -92,3 +92,16 @@ class BaseDataModule(pl.LightningDataModule):
             shuffle=False,
             pin_memory=self.pin_memory,
         )
+
+
+def get_dataset_from_module(module, stage: str, **module_args):
+    """Instantiates a DataModule and retrieves the corresponding dataset.
+
+    Useful when combining multiple datasets.
+    """
+    module_instance = module(**module_args)
+    module_instance.prepare_data()  # e.g. download or verify the raw data
+    module_instance.setup(stage=stage)  # create the datasets for this stage
+    dataset = module_instance.dataset
+
+    return dataset
diff --git a/src/ptbench/data/mc_ch/default.py b/src/ptbench/data/mc_ch/default.py
index 0af1d686..cf901fdb 100644
--- a/src/ptbench/data/mc_ch/default.py
+++ b/src/ptbench/data/mc_ch/default.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.default import datamodule as mc_datamodule
 from ..shenzhen.default import datamodule as ch_datamodule
 
@@ -37,27 +37,16 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
 
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_0.py b/src/ptbench/data/mc_ch/fold_0.py
index e151c0e0..eeb8f512 100644
--- a/src/ptbench/data/mc_ch/fold_0.py
+++ b/src/ptbench/data/mc_ch/fold_0.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_0 import datamodule as mc_datamodule
 from ..shenzhen.fold_0 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_0_rgb.py b/src/ptbench/data/mc_ch/fold_0_rgb.py
index 56502b24..6c8c5aeb 100644
--- a/src/ptbench/data/mc_ch/fold_0_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_0_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_0_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_0_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_1.py b/src/ptbench/data/mc_ch/fold_1.py
index 732513da..b6bcb0f7 100644
--- a/src/ptbench/data/mc_ch/fold_1.py
+++ b/src/ptbench/data/mc_ch/fold_1.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_1 import datamodule as mc_datamodule
 from ..shenzhen.fold_1 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_1_rgb.py b/src/ptbench/data/mc_ch/fold_1_rgb.py
index 6cfbcb36..25709510 100644
--- a/src/ptbench/data/mc_ch/fold_1_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_1_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_1_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_1_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_2.py b/src/ptbench/data/mc_ch/fold_2.py
index 1d4ac5b5..e3ac99ec 100644
--- a/src/ptbench/data/mc_ch/fold_2.py
+++ b/src/ptbench/data/mc_ch/fold_2.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_2 import datamodule as mc_datamodule
 from ..shenzhen.fold_2 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_2_rgb.py b/src/ptbench/data/mc_ch/fold_2_rgb.py
index eec98dca..1bbce20d 100644
--- a/src/ptbench/data/mc_ch/fold_2_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_2_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_2_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_2_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_3.py b/src/ptbench/data/mc_ch/fold_3.py
index b97b5e94..ed58cac7 100644
--- a/src/ptbench/data/mc_ch/fold_3.py
+++ b/src/ptbench/data/mc_ch/fold_3.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_3 import datamodule as mc_datamodule
 from ..shenzhen.fold_3 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_3_rgb.py b/src/ptbench/data/mc_ch/fold_3_rgb.py
index e380dc33..7d9401f0 100644
--- a/src/ptbench/data/mc_ch/fold_3_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_3_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_3_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_3_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_4.py b/src/ptbench/data/mc_ch/fold_4.py
index 8cd906b5..761d730e 100644
--- a/src/ptbench/data/mc_ch/fold_4.py
+++ b/src/ptbench/data/mc_ch/fold_4.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_4 import datamodule as mc_datamodule
 from ..shenzhen.fold_4 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_4_rgb.py b/src/ptbench/data/mc_ch/fold_4_rgb.py
index 7ba0ecfe..3f1a1476 100644
--- a/src/ptbench/data/mc_ch/fold_4_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_4_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_4_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_4_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_5.py b/src/ptbench/data/mc_ch/fold_5.py
index 3f20a33b..01a8f31e 100644
--- a/src/ptbench/data/mc_ch/fold_5.py
+++ b/src/ptbench/data/mc_ch/fold_5.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_5 import datamodule as mc_datamodule
 from ..shenzhen.fold_5 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_5_rgb.py b/src/ptbench/data/mc_ch/fold_5_rgb.py
index 61159ed6..cf66986f 100644
--- a/src/ptbench/data/mc_ch/fold_5_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_5_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_5_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_5_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_6.py b/src/ptbench/data/mc_ch/fold_6.py
index 74137910..5b488da9 100644
--- a/src/ptbench/data/mc_ch/fold_6.py
+++ b/src/ptbench/data/mc_ch/fold_6.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_6 import datamodule as mc_datamodule
 from ..shenzhen.fold_6 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_6_rgb.py b/src/ptbench/data/mc_ch/fold_6_rgb.py
index 79abe09b..ecadb9bf 100644
--- a/src/ptbench/data/mc_ch/fold_6_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_6_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_6_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_6_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_7.py b/src/ptbench/data/mc_ch/fold_7.py
index a94621e6..6514fd1a 100644
--- a/src/ptbench/data/mc_ch/fold_7.py
+++ b/src/ptbench/data/mc_ch/fold_7.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_7 import datamodule as mc_datamodule
 from ..shenzhen.fold_7 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_7_rgb.py b/src/ptbench/data/mc_ch/fold_7_rgb.py
index 90b866e1..9123adf3 100644
--- a/src/ptbench/data/mc_ch/fold_7_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_7_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_7_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_7_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_8.py b/src/ptbench/data/mc_ch/fold_8.py
index aa52bc81..e6050636 100644
--- a/src/ptbench/data/mc_ch/fold_8.py
+++ b/src/ptbench/data/mc_ch/fold_8.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_8 import datamodule as mc_datamodule
 from ..shenzhen.fold_8 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_8_rgb.py b/src/ptbench/data/mc_ch/fold_8_rgb.py
index 3df1838d..4ff54baa 100644
--- a/src/ptbench/data/mc_ch/fold_8_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_8_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_8_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_8_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_9.py b/src/ptbench/data/mc_ch/fold_9.py
index 4bb4a5a3..cdd69ba7 100644
--- a/src/ptbench/data/mc_ch/fold_9.py
+++ b/src/ptbench/data/mc_ch/fold_9.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_9 import datamodule as mc_datamodule
 from ..shenzhen.fold_9 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_9_rgb.py b/src/ptbench/data/mc_ch/fold_9_rgb.py
index a07ffce4..465e6c3d 100644
--- a/src/ptbench/data/mc_ch/fold_9_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_9_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_9_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_9_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/rgb.py b/src/ptbench/data/mc_ch/rgb.py
index 05407eae..4eee3c86 100644
--- a/src/ptbench/data/mc_ch/rgb.py
+++ b/src/ptbench/data/mc_ch/rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.rgb import datamodule as mc_datamodule
 from ..shenzhen.rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
-- 
GitLab