diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index efbcfaf9e930025a5c0583221d2b8626eb033e3e..d3c03d76578b624f4d38a06e255501f91e2fab49 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -67,17 +67,7 @@ class DataModule(pl.LightningDataModule):
                 self.extra_validation_datasets = None
 
         if stage == "predict":
-            self.predict_dataset = []
-
-            for split_key in self.dataset.keys():
-                if split_key.startswith("_"):
-                    logger.info(
-                        f"Skipping dataset '{split_key}' (not to be evaluated)"
-                    )
-                    continue
-
-                else:
-                    self.predict_dataset.append(self.dataset[split_key])
+            self.predict_dataset = self.dataset
 
     def train_dataloader(self):
         train_samples_weights = get_samples_weights(self.train_dataset)
@@ -127,14 +117,9 @@ class DataModule(pl.LightningDataModule):
         return loaders_dict
 
     def predict_dataloader(self):
-        loaders_dict = {}
-
-        for set_idx, pred_set in enumerate(self.predict_dataset):
-            loaders_dict[set_idx] = DataLoader(
-                dataset=pred_set,
-                batch_size=self.predict_batch_size,
-                shuffle=False,
-                pin_memory=self.pin_memory,
-            )
-
-        return loaders_dict
+        return DataLoader(
+            dataset=self.predict_dataset,
+            batch_size=self.predict_batch_size,
+            shuffle=False,
+            pin_memory=self.pin_memory,
+        )
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index 3613e75ad03c7f64ce9ce2c81e013d251fdf2858..52dc98f540247e0743608f53b5496f0d1ffbf89e 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -117,6 +117,7 @@ def predict(
     from sklearn import metrics
     from torch.utils.data import ConcatDataset, DataLoader
 
+    from ..data.datamodule import DataModule
     from ..engine.predictor import run
     from ..utils.plot import relevance_analysis_plot
 
@@ -147,14 +148,13 @@ def predict(
 
         logger.info(f"Running inference on '{k}' set...")
 
-        data_loader = DataLoader(
-            dataset=v,
-            batch_size=batch_size,
-            shuffle=False,
-            pin_memory=torch.cuda.is_available(),
+        datamodule = DataModule(
+            v,
+            train_batch_size=batch_size,
         )
+
         predictions = run(
-            model, data_loader, k, accelerator, output_folder, grad_cams
+            model, datamodule, k, accelerator, output_folder, grad_cams
         )
 
         # Relevance analysis using permutation feature importance