diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py
index fb0970f0b2c1ebb1cfcae291abd626428f556bf6..5e656d428c6ebe64dde7858f368149481b8e1c2e 100644
--- a/src/ptbench/data/base_datamodule.py
+++ b/src/ptbench/data/base_datamodule.py
@@ -92,3 +92,16 @@ class BaseDataModule(pl.LightningDataModule):
             shuffle=False,
             pin_memory=self.pin_memory,
         )
+
+
+def get_dataset_from_module(module, stage, **module_args):
+    """Instantiate a DataModule and retrieve its dataset for the given stage.
+
+    Useful when combining multiple datasets.
+    """
+    module_instance = module(**module_args)
+    module_instance.prepare_data()
+    module_instance.setup(stage=stage)
+    dataset = module_instance.dataset
+
+    return dataset
diff --git a/src/ptbench/data/mc_ch/default.py b/src/ptbench/data/mc_ch/default.py
index 0af1d68682ec2b91ae8febf1adcebf22704f0982..cf901fdb41b14037c1827d364fedee1f70b1a665 100644
--- a/src/ptbench/data/mc_ch/default.py
+++ b/src/ptbench/data/mc_ch/default.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.default import datamodule as mc_datamodule
 from ..shenzhen.default import datamodule as ch_datamodule
 
@@ -37,27 +37,16 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
 
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_0.py b/src/ptbench/data/mc_ch/fold_0.py
index e151c0e0eb0703b8a8b3bc83f68d8f020053cc14..eeb8f5128f56dc41f383ebd1d58b67b8cfd5ae27 100644
--- a/src/ptbench/data/mc_ch/fold_0.py
+++ b/src/ptbench/data/mc_ch/fold_0.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_0 import datamodule as mc_datamodule
 from ..shenzhen.fold_0 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_0_rgb.py b/src/ptbench/data/mc_ch/fold_0_rgb.py
index 56502b24651879a8106ddd6228eee70e1c4de205..6c8c5aebce6a1b972a8f418a66ff2d7cbdc5ae30 100644
--- a/src/ptbench/data/mc_ch/fold_0_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_0_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_0_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_0_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_1.py b/src/ptbench/data/mc_ch/fold_1.py
index 732513da40421d271b86a087731a56f6e1bafbd5..b6bcb0f70e15097da82a2b8f824ee7178e5f1972 100644
--- a/src/ptbench/data/mc_ch/fold_1.py
+++ b/src/ptbench/data/mc_ch/fold_1.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_1 import datamodule as mc_datamodule
 from ..shenzhen.fold_1 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_1_rgb.py b/src/ptbench/data/mc_ch/fold_1_rgb.py
index 6cfbcb366f71146cc7f20c5d27ea1998cd209bc1..2570951009386568ac10cf4a7de705c54ea71a9b 100644
--- a/src/ptbench/data/mc_ch/fold_1_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_1_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_1_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_1_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_2.py b/src/ptbench/data/mc_ch/fold_2.py
index 1d4ac5b58bb7d2500932061be0e9a252139e6274..e3ac99ece5ded748c6d1cd772846986e3ff1d9d0 100644
--- a/src/ptbench/data/mc_ch/fold_2.py
+++ b/src/ptbench/data/mc_ch/fold_2.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_2 import datamodule as mc_datamodule
 from ..shenzhen.fold_2 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_2_rgb.py b/src/ptbench/data/mc_ch/fold_2_rgb.py
index eec98dca20441bf30fb6ae8250303dd970b68148..1bbce20d34d7d2083fc4fb35effd6022b6028b28 100644
--- a/src/ptbench/data/mc_ch/fold_2_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_2_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_2_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_2_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_3.py b/src/ptbench/data/mc_ch/fold_3.py
index b97b5e944b94797c518a56b8f0021b5b890766ec..ed58cac744caa2e7a30378c90c967d2406a30ef5 100644
--- a/src/ptbench/data/mc_ch/fold_3.py
+++ b/src/ptbench/data/mc_ch/fold_3.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_3 import datamodule as mc_datamodule
 from ..shenzhen.fold_3 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_3_rgb.py b/src/ptbench/data/mc_ch/fold_3_rgb.py
index e380dc33b3653632a3c8f7f5324ab2b0bfb737e2..7d9401f0ece0f316c958cebb2eed35f5ec0cabbc 100644
--- a/src/ptbench/data/mc_ch/fold_3_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_3_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_3_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_3_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_4.py b/src/ptbench/data/mc_ch/fold_4.py
index 8cd906b53064644499f45a07f082e2f91acb76fc..761d730e0c4a00c230ec47be966d52bfce72c3bc 100644
--- a/src/ptbench/data/mc_ch/fold_4.py
+++ b/src/ptbench/data/mc_ch/fold_4.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_4 import datamodule as mc_datamodule
 from ..shenzhen.fold_4 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_4_rgb.py b/src/ptbench/data/mc_ch/fold_4_rgb.py
index 7ba0ecfebd957bf6778fd0c34731974ddec50434..3f1a147634fa41d06299765ec44aab2d1271e429 100644
--- a/src/ptbench/data/mc_ch/fold_4_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_4_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_4_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_4_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_5.py b/src/ptbench/data/mc_ch/fold_5.py
index 3f20a33b633df0aba8b65e5849709634b96cce6e..01a8f31ebca9871ce5e5cee8e5c9106c86335362 100644
--- a/src/ptbench/data/mc_ch/fold_5.py
+++ b/src/ptbench/data/mc_ch/fold_5.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_5 import datamodule as mc_datamodule
 from ..shenzhen.fold_5 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_5_rgb.py b/src/ptbench/data/mc_ch/fold_5_rgb.py
index 61159ed6e37be4e898ac37d3d2b1cf4f71b5347a..cf66986fe9b48119f0a5510ad6d7547513542346 100644
--- a/src/ptbench/data/mc_ch/fold_5_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_5_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_5_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_5_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_6.py b/src/ptbench/data/mc_ch/fold_6.py
index 7413791054f2d14948f3ee4f5f12817e172872a8..5b488da973f00e5a6718eac6117016ccfe3a359f 100644
--- a/src/ptbench/data/mc_ch/fold_6.py
+++ b/src/ptbench/data/mc_ch/fold_6.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_6 import datamodule as mc_datamodule
 from ..shenzhen.fold_6 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_6_rgb.py b/src/ptbench/data/mc_ch/fold_6_rgb.py
index 79abe09b973fe27a54af0cd1db6d855f1e93d2f4..ecadb9bfabfd1348281c63eb97339c5bad990f6a 100644
--- a/src/ptbench/data/mc_ch/fold_6_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_6_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_6_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_6_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_7.py b/src/ptbench/data/mc_ch/fold_7.py
index a94621e61d846e596ff8422c8abce39c896ee028..6514fd1aba701fff1dcaaaf4791d6ae9b861de22 100644
--- a/src/ptbench/data/mc_ch/fold_7.py
+++ b/src/ptbench/data/mc_ch/fold_7.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_7 import datamodule as mc_datamodule
 from ..shenzhen.fold_7 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_7_rgb.py b/src/ptbench/data/mc_ch/fold_7_rgb.py
index 90b866e1d0b506e3d5d03a76472815fb6165294f..9123adf3676c08acbd7d32a1b2747ffd314dae38 100644
--- a/src/ptbench/data/mc_ch/fold_7_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_7_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_7_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_7_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_8.py b/src/ptbench/data/mc_ch/fold_8.py
index aa52bc818fd9dc4daff55c6691ada122813fbfd8..e60506363bf508a39f3bb2c9806f8bd3bcc9e83c 100644
--- a/src/ptbench/data/mc_ch/fold_8.py
+++ b/src/ptbench/data/mc_ch/fold_8.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_8 import datamodule as mc_datamodule
 from ..shenzhen.fold_8 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_8_rgb.py b/src/ptbench/data/mc_ch/fold_8_rgb.py
index 3df1838d21c7c8c3cdb4fb686edad87f053e564f..4ff54baa2f5424993ed747edd4ba136da9324981 100644
--- a/src/ptbench/data/mc_ch/fold_8_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_8_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_8_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_8_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_9.py b/src/ptbench/data/mc_ch/fold_9.py
index 4bb4a5a3c9b70a35d4cd7665bc5221ae3948419e..cdd69ba7deeeb2759cdabc1134379145d5537696 100644
--- a/src/ptbench/data/mc_ch/fold_9.py
+++ b/src/ptbench/data/mc_ch/fold_9.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_9 import datamodule as mc_datamodule
 from ..shenzhen.fold_9 import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_9_rgb.py b/src/ptbench/data/mc_ch/fold_9_rgb.py
index a07ffce4a0b6d567e85eff06162a845a3f4c6bab..465e6c3d3d2c65eb367f7a97d7bb296a3be3dab1 100644
--- a/src/ptbench/data/mc_ch/fold_9_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_9_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_9_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_9_rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/rgb.py b/src/ptbench/data/mc_ch/rgb.py
index 05407eae2ee826f5706770bbec3feb2e6aa9426f..4eee3c86f6150859a45ecd0e5e9d22dc45adc637 100644
--- a/src/ptbench/data/mc_ch/rgb.py
+++ b/src/ptbench/data/mc_ch/rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.rgb import datamodule as mc_datamodule
 from ..shenzhen.rgb import datamodule as ch_datamodule
 
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}