diff --git a/src/ptbench/configs/datasets/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py
index 6dc59200b2c751d7e304393a9b28abee18d40fa5..967d40e9358c8f946f876199cac56a921a89606f 100644
--- a/src/ptbench/configs/datasets/shenzhen/default.py
+++ b/src/ptbench/configs/datasets/shenzhen/default.py
@@ -38,6 +38,7 @@ class DefaultModule(BaseDataModule):
 
         self.cache_samples = cache_samples
         self.has_setup_fit = False
+        self.has_setup_predict = False
 
     def setup(self, stage: str):
         if self.cache_samples:
@@ -51,7 +52,7 @@ class DefaultModule(BaseDataModule):
             )
             samples_loader = _delayed_loader
 
-        self.json_dataset = JSONDataset(
+        json_dataset = JSONDataset(
             protocols=_protocols,
             fieldnames=("data", "label"),
             loader=samples_loader,
@@ -62,8 +63,17 @@ class DefaultModule(BaseDataModule):
                 self.train_dataset,
                 self.validation_dataset,
                 self.extra_validation_datasets,
-            ) = return_subsets(self.json_dataset, "default", stage)
+            ) = return_subsets(json_dataset, "default", stage)
             self.has_setup_fit = True
 
+        if not self.has_setup_predict and stage == "predict":
+            (
+                self.train_dataset,
+                self.validation_dataset,
+                self.extra_validation_datasets,
+            ) = return_subsets(json_dataset, "default", stage)
+
+            self.has_setup_predict = True
+
 
-datamodule = DefaultModule
+datamodule = DefaultModule()
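
For reference, the guard pattern this hunk extends can be summarized outside the diff. Lightning may call `setup()` more than once (once per stage, and once per process under distributed training), so each stage branch is made idempotent with a boolean flag. A minimal sketch, with everything except the flag names assumed:

    import pytorch_lightning as pl  # or: import lightning.pytorch as pl

    class GuardedModule(pl.LightningDataModule):
        def __init__(self):
            super().__init__()
            self.has_setup_fit = False
            self.has_setup_predict = False

        def setup(self, stage: str):
            # setup() can run repeatedly; the flags keep each branch one-shot.
            if not self.has_setup_fit and stage == "fit":
                ...  # build train/validation subsets here
                self.has_setup_fit = True
            if not self.has_setup_predict and stage == "predict":
                ...  # reuse the same subsets for prediction
                self.has_setup_predict = True
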
diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py
index 5e656d428c6ebe64dde7858f368149481b8e1c2e..1c51a1055dc226efc72292a3e439041c67a6d37b 100644
--- a/src/ptbench/data/base_datamodule.py
+++ b/src/ptbench/data/base_datamodule.py
@@ -18,13 +18,15 @@ class BaseDataModule(pl.LightningDataModule):
         self,
         train_batch_size=1,
         predict_batch_size=1,
+        batch_chunk_count=1,
         drop_incomplete_batch=False,
-        multiproc_kwargs=None,
+        multiproc_kwargs={},  # mutable default: replaced, never mutated
     ):
         super().__init__()
 
         self.train_batch_size = train_batch_size
         self.predict_batch_size = predict_batch_size
+        self.batch_chunk_count = batch_chunk_count
 
         self.drop_incomplete_batch = drop_incomplete_batch
         self.pin_memory = (
@@ -47,7 +49,7 @@ class BaseDataModule(pl.LightningDataModule):
 
         return DataLoader(
             self.train_dataset,
-            batch_size=self.train_batch_size,
+            batch_size=self.compute_chunk_size(self.train_batch_size),
             drop_last=self.drop_incomplete_batch,
             pin_memory=self.pin_memory,
             sampler=train_sampler,
@@ -59,7 +61,7 @@ class BaseDataModule(pl.LightningDataModule):
 
         val_loader = DataLoader(
             dataset=self.validation_dataset,
-            batch_size=self.train_batch_size,
+            batch_size=self.compute_chunk_size(self.train_batch_size),
             shuffle=False,
             drop_last=False,
             pin_memory=self.pin_memory,
@@ -86,12 +88,22 @@
         return loaders_dict
 
     def predict_dataloader(self):
-        return DataLoader(
-            dataset=self.predict_dataset,
-            batch_size=self.predict_batch_size,
-            shuffle=False,
-            pin_memory=self.pin_memory,
-        )
+        loaders_dict = {}
+
+        loaders_dict["train_dataloader"] = self.train_dataloader()
+        for k, v in self.val_dataloader().items():
+            loaders_dict[k] = v
+
+        return loaders_dict
+
+    def compute_chunk_size(self, batch_size):
+        """Split the batch size into batch_chunk_count equal chunks."""
+        if batch_size % self.batch_chunk_count != 0:
+            raise RuntimeError(
+                f"--batch-size ({batch_size}) must be divisible by "
+                f"--batch-chunk-count ({self.batch_chunk_count})."
+            )
+        return batch_size // self.batch_chunk_count
 
 
 def get_dataset_from_module(module, stage, **module_args):
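
The chunking arithmetic is easiest to see in isolation. A standalone sketch of the rule implemented by `compute_chunk_size` above (the free function is hypothetical; only the divisibility check and integer division come from the diff):

    def compute_chunk_size(batch_size: int, batch_chunk_count: int) -> int:
        # The requested batch size must split into equal chunks.
        if batch_size % batch_chunk_count != 0:
            raise RuntimeError(
                f"--batch-size ({batch_size}) must be divisible by "
                f"--batch-chunk-count ({batch_chunk_count})."
            )
        return batch_size // batch_chunk_count

    assert compute_chunk_size(16, 4) == 4   # dataloaders yield batches of 4
    assert compute_chunk_size(16, 1) == 16  # chunking disabled
    # compute_chunk_size(16, 3) raises: 16 is not divisible by 3

The dataloaders then operate on the chunk size; the effective batch size is presumably recovered by accumulating gradients over `batch_chunk_count` steps (e.g. Lightning's `Trainer(accumulate_grad_batches=...)`, which is not part of this diff).
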
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index c59c81a3bc8f325ad20d883d58b8fa81bd51bbb7..3606c9f9d2c0179f171acc4b3dcba7987242047e 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -288,22 +288,12 @@ def train(
             "multiprocessing_context"
         ] = multiprocessing.get_context("spawn")
 
-    batch_chunk_size = batch_size
-    if batch_size % batch_chunk_count != 0:
-        # batch_size must be divisible by batch_chunk_count.
-        raise RuntimeError(
-            f"--batch-size ({batch_size}) must be divisible by "
-            f"--batch-chunk-size ({batch_chunk_count})."
-        )
-    else:
-        batch_chunk_size = batch_size // batch_chunk_count
+    datamodule.train_batch_size = batch_size
+    datamodule.batch_chunk_count = batch_chunk_count
+    datamodule.drop_incomplete_batch = drop_incomplete_batch
+    datamodule.cache_samples = cache_samples
+    datamodule.multiproc_kwargs = multiproc_kwargs
 
-    datamodule = datamodule(
-        train_batch_size=batch_chunk_size,
-        drop_incomplete_batch=drop_incomplete_batch,
-        multiproc_kwargs=multiproc_kwargs,
-        cache_samples=cache_samples,
-    )
     # Manually calling these as we need to access some values to reweight the criterion
     datamodule.prepare_data()
     datamodule.setup(stage="fit")
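
The calling convention thus changes from "instantiate the exported class with CLI values" to "override attributes on a pre-built instance". A minimal sketch of the resulting flow, assuming an effective batch size of 16 split into 4 chunks (the values are illustrative only):

    datamodule.train_batch_size = 16   # effective batch size
    datamodule.batch_chunk_count = 4   # dataloader batches of 16 // 4 = 4
    datamodule.multiproc_kwargs = {}   # e.g. the "spawn" context set above

    datamodule.prepare_data()
    datamodule.setup(stage="fit")      # runs once; guarded by has_setup_fit

Because the config module now exports an instance (`datamodule = DefaultModule()`), any attribute set here persists for the rest of the script, and the divisibility check moves out of train.py into the datamodule itself.
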