diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py
index 5dfbbd7e6e0b00df0d1a6208a7bb70351b1cb417..3bcd441a59a8e08c2e4d9aa33808c8d651026100 100644
--- a/src/ptbench/data/base_datamodule.py
+++ b/src/ptbench/data/base_datamodule.py
@@ -109,7 +109,7 @@ class BaseDataModule(pl.LightningDataModule):
     def predict_dataloader(self):
         loaders_dict = {}
 
-        loaders_dict["train_dataloader"] = self.train_dataloader()
+        loaders_dict["train_loader"] = self.train_dataloader()
         for k, v in self.val_dataloader().items():
             loaders_dict[k] = v
 
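
A minimal sketch of the dictionary shape `predict_dataloader()` now returns. The `"validation_loader"` key is an assumption about what `val_dataloader()` yields; only the `_loader` suffix matters downstream, since `PredictionsWriter` strips it to name the output sub-directories:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dummy = TensorDataset(torch.zeros(4, 1))  # stand-in dataset

loaders_dict = {
    "train_loader": DataLoader(dummy, batch_size=2),
    "validation_loader": DataLoader(dummy, batch_size=2),  # assumed key
}

# PredictionsWriter derives one output sub-directory per key:
for key in loaders_dict:
    print(key.replace("_loader", ""))  # -> train, validation
```
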
diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index c78d52ea806310c087ecda943b269ec4129a4d11..d0ac43f98e21b8ce6803797d6a1fde38c6302660 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -1,4 +1,5 @@
 import csv
+import os
 import time
 
 from collections import defaultdict
@@ -141,24 +142,40 @@ class LoggingCallback(Callback):
 class PredictionsWriter(BasePredictionWriter):
     """Lightning callback to write predictions to a file."""
 
-    def __init__(self, logfile_name, logfile_fields, write_interval):
+    def __init__(self, output_dir, logfile_fields, write_interval):
         super().__init__(write_interval)
-        self.logfile_name = logfile_name
+        self.output_dir = output_dir
         self.logfile_fields = logfile_fields
 
     def write_on_epoch_end(
         self, trainer, pl_module, predictions, batch_indices
     ):
-        with open(self.logfile_name, "w") as logfile:
-            logwriter = csv.DictWriter(logfile, fieldnames=self.logfile_fields)
-            logwriter.writeheader()
-
-            for prediction in predictions:
-                logwriter.writerow(
-                    {
-                        "filename": prediction[0],
-                        "likelihood": prediction[1].numpy(),
-                        "ground_truth": prediction[2].numpy(),
-                    }
-                )
-            logfile.flush()
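+        # One list of per-batch results per predict dataloader; name each
+        # output sub-directory after the dataloader key declared by the
+        # datamodule (e.g. "train_loader" -> "train")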
+        dataloader_names = list(trainer.datamodule.predict_dataloader().keys())
+
+        for dataloader_idx, dataloader_results in enumerate(predictions):
+            dataloader_name = dataloader_names[dataloader_idx].replace(
+                "_loader", ""
+            )
+
+            logfile = os.path.join(
+                self.output_dir, dataloader_name, "predictions.csv"
+            )
+            os.makedirs(os.path.dirname(logfile), exist_ok=True)
+
+            with open(logfile, "w") as l_f:
+                logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields)
+                logwriter.writeheader()
+
+                for prediction in dataloader_results:
+                    logwriter.writerow(
+                        {
+                            "filename": prediction[0],
+                            "likelihood": prediction[1].numpy(),
+                            "ground_truth": prediction[2].numpy(),
+                        }
+                    )
+                l_f.flush()
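
To sanity-check the writer, here is the same logic reduced to a standalone sketch, with the trainer and datamodule replaced by hand-built stand-ins; the file names and tensor values are invented:

```python
import csv
import os

import torch

output_dir = "results"  # hypothetical location
dataloader_names = ["train_loader", "validation_loader"]
predictions = [  # one inner list per predict dataloader, as Lightning collects them
    [("img_0001.png", torch.tensor([0.87]), torch.tensor([1]))],
    [("img_0042.png", torch.tensor([0.12]), torch.tensor([0]))],
]

for idx, results in enumerate(predictions):
    name = dataloader_names[idx].replace("_loader", "")
    logfile = os.path.join(output_dir, name, "predictions.csv")
    os.makedirs(os.path.dirname(logfile), exist_ok=True)

    with open(logfile, "w") as f:
        writer = csv.DictWriter(
            f, fieldnames=("filename", "likelihood", "ground_truth")
        )
        writer.writeheader()
        for filename, likelihood, ground_truth in results:
            writer.writerow(
                {
                    "filename": filename,
                    "likelihood": likelihood.numpy(),
                    "ground_truth": ground_truth.numpy(),
                }
            )

# -> results/train/predictions.csv and results/validation/predictions.csv
```
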
diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py
index 4535b368bca54daf97fa350b91516a126eff319a..5dcbb79c9fd0a8c32f9d269f8302b888da56be84 100644
--- a/src/ptbench/engine/predictor.py
+++ b/src/ptbench/engine/predictor.py
@@ -13,7 +13,7 @@ from .callbacks import PredictionsWriter
 logger = logging.getLogger(__name__)
 
 
-def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
+def run(model, datamodule, accelerator, output_folder, grad_cams=False):
     """Runs inference on input data, outputs csv files with predictions.
 
     Parameters
@@ -24,10 +24,6 @@ def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
-    data_loader : py:class:`torch.torch.utils.data.DataLoader`
-        The pytorch Dataloader used to iterate over batches.
+    datamodule : py:class:`pytorch_lightning.LightningDataModule`
+        The lightning datamodule providing the dataloaders used for prediction.
 
-    name : str
-        The local name of this dataset (e.g. ``train``, or ``test``), to be
-        used when saving measures files.
-
     accelerator : str
         A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)
 
@@ -44,7 +40,6 @@ def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
     all_predictions : list
         All the predictions associated with filename and ground truth.
     """
-    output_folder = os.path.join(output_folder, name)
 
     logger.info(f"Output folder: {output_folder}")
     os.makedirs(output_folder, exist_ok=True)
@@ -58,7 +53,6 @@ def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
 
     logger.info(f"Device: {devices}")
 
-    logfile_name = os.path.join(output_folder, "predictions.csv")
     logfile_fields = ("filename", "likelihood", "ground_truth")
 
     trainer = Trainer(
@@ -66,7 +60,7 @@ def run(model, datamodule, name, accelerator, output_folder, grad_cams=False):
         devices=devices,
         callbacks=[
             PredictionsWriter(
-                logfile_name=logfile_name,
+                output_dir=output_folder,
                 logfile_fields=logfile_fields,
                 write_interval="epoch",
             ),
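
With the per-dataset `name` parameter gone, one `run()` call now covers every predict dataloader. A hedged usage sketch, where `model` and `datamodule` are placeholders the caller supplies (e.g. via the CLI resource options):

```python
from ptbench.engine.predictor import run


def predict_all(model, datamodule, output_folder):
    """model: a loaded LightningModule checkpoint; datamodule: an object
    whose predict_dataloader() returns a dict of named dataloaders."""
    # One call writes <output_folder>/<subset>/predictions.csv per dataloader
    return run(model, datamodule, "cpu", output_folder, grad_cams=False)
```
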
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index d9e86b7055ed1ca3a06d807ab5d02b175b2f0cef..9da9702f31436df441cdb56d9415cc06e6829623 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -169,8 +169,9 @@ class PASA(pl.LightningModule):
             return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch["names"]
+        names = batch["name"]
         images = batch["data"]
+        labels = batch["label"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
@@ -180,18 +181,15 @@
         if isinstance(outputs, list):
             outputs = outputs[-1]
 
-        return {
-            f"dataloader_{dataloader_idx}_predictions": (
-                names[0],
-                torch.flatten(probabilities),
-                torch.flatten(batch[2]),
-            )
-        }
-
-    def on_predict_epoch_end(self):
-        # Need to cache predictions in the predict step, then reorder by key
-        # Clear prediction dict
-        raise NotImplementedError
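+        # NOTE: only the first name in the batch is recorded, which assumes
+        # prediction runs with batch_size == 1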
+        results = (
+            names[0],
+            torch.flatten(probabilities),
+            torch.flatten(labels),
+        )
+
+        return results
 
     def configure_optimizers(self):
         # Dynamically instantiates the optimizer given the configs
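
In miniature, what `predict_step` now returns for one batch (assuming `batch_size == 1`, since the tuple keeps only `names[0]`). `trainer.predict()` stacks these per batch and, with several predict dataloaders, per dataloader, which is exactly the nesting `PredictionsWriter` consumes:

```python
import torch

outputs = torch.tensor([[1.5]])         # raw logits from self(images)
probabilities = torch.sigmoid(outputs)  # tensor([[0.8176]])

results = (
    "img_0001.png",                      # names[0]
    torch.flatten(probabilities),        # tensor([0.8176])
    torch.flatten(torch.tensor([[1]])),  # labels -> tensor([1])
)
print(results)
```
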
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index 52dc98f540247e0743608f53b5496f0d1ffbf89e..a78d74b41d8f75b3f4466a890f206e4b2503a84c 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -41,7 +41,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @click.option(
-    "--dataset",
+    "--datamodule",
     "-d",
     help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
     "to be used for running prediction, possibly including all pre-processing "
@@ -77,14 +77,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     required=True,
     cls=ResourceOption,
 )
-@click.option(
-    "--relevance-analysis",
-    "-r",
-    help="If set, generate relevance analysis pdfs to indicate the relative"
-    "importance of each feature",
-    is_flag=True,
-    cls=ResourceOption,
-)
 @click.option(
     "--grad-cams",
     "-g",
@@ -96,32 +88,27 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 def predict(
     output_folder,
     model,
-    dataset,
+    datamodule,
     batch_size,
     accelerator,
     weight,
-    relevance_analysis,
     grad_cams,
     **_,
 ) -> None:
     """Predicts Tuberculosis presence (probabilities) on input images."""
 
-    import copy
     import os
-    import shutil
 
     import numpy as np
-    import torch
 
     from matplotlib.backends.backend_pdf import PdfPages
-    from sklearn import metrics
-    from torch.utils.data import ConcatDataset, DataLoader
 
-    from ..data.datamodule import DataModule
     from ..engine.predictor import run
     from ..utils.plot import relevance_analysis_plot
 
-    dataset = dataset if isinstance(dataset, dict) else dict(test=dataset)
+    datamodule = datamodule(
+        batch_size=batch_size,
+    )
 
     logger.info(f"Loading checkpoint from {weight}")
     model = model.load_from_checkpoint(weight, strict=False)
@@ -141,83 +128,4 @@ def predict(
         )
         pdf.close()
 
-    for k, v in dataset.items():
-        if k.startswith("_"):
-            logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
-            continue
-
-        logger.info(f"Running inference on '{k}' set...")
-
-        datamodule = DataModule(
-            v,
-            train_batch_size=batch_size,
-        )
-
-        predictions = run(
-            model, datamodule, k, accelerator, output_folder, grad_cams
-        )
-
-        # Relevance analysis using permutation feature importance
-        if relevance_analysis:
-            if isinstance(v, ConcatDataset) or not isinstance(
-                v._samples[0].data["data"], list
-            ):
-                logger.info(
-                    "Relevance analysis only possible with radiological signs as input. Cancelling..."
-                )
-                continue
-
-            nb_features = len(v._samples[0].data["data"])
-
-            if nb_features == 1:
-                logger.info("Relevance analysis not possible with one feature")
-            else:
-                logger.info(f"Starting relevance analysis for subset '{k}'...")
-
-                all_mse = []
-                for f in range(nb_features):
-                    v_original = copy.deepcopy(v)
-
-                    # Randomly permute feature values from all samples
-                    v.random_permute(f)
-
-                    data_loader = DataLoader(
-                        dataset=v,
-                        batch_size=batch_size,
-                        shuffle=False,
-                        pin_memory=torch.cuda.is_available(),
-                    )
-
-                    predictions_with_mean = run(
-                        model,
-                        data_loader,
-                        k,
-                        accelerator,
-                        output_folder + "_temp",
-                    )
-
-                    # Compute MSE between original and new predictions
-                    all_mse.append(
-                        metrics.mean_squared_error(
-                            np.array(predictions, dtype=object)[:, 1],
-                            np.array(predictions_with_mean, dtype=object)[:, 1],
-                        )
-                    )
-
-                    # Back to original values
-                    v = v_original
-
-                # Remove temporary folder
-                shutil.rmtree(output_folder + "_temp", ignore_errors=True)
-
-                filepath = os.path.join(output_folder, k + "_RA.pdf")
-                logger.info(f"Creating and saving plot at {filepath}...")
-                os.makedirs(os.path.dirname(filepath), exist_ok=True)
-                pdf = PdfPages(filepath)
-                pdf.savefig(
-                    relevance_analysis_plot(
-                        all_mse,
-                        title=k.capitalize() + " set relevance analysis",
-                    )
-                )
-                pdf.close()
+    _ = run(model, datamodule, accelerator, output_folder, grad_cams)
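
A minimal sketch of the contract the rewritten script assumes: the `--datamodule` resource resolves to a callable (a class or factory) that accepts `batch_size` and returns an object exposing a dict-valued `predict_dataloader()`. Everything below is illustrative, not real ptbench configuration:

```python
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyDataModule(pl.LightningDataModule):
    """Stand-in for whatever the --datamodule resource resolves to."""

    def __init__(self, batch_size=1):
        super().__init__()
        self.batch_size = batch_size
        self._data = TensorDataset(torch.zeros(4, 1))

    def predict_dataloader(self):
        # dict of named loaders, as BaseDataModule.predict_dataloader() builds
        return {"train_loader": DataLoader(self._data, batch_size=self.batch_size)}


datamodule = ToyDataModule(batch_size=2)  # mirrors datamodule(batch_size=...)
```
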
diff --git a/src/ptbench/scripts/relevance_analysis.py b/src/ptbench/scripts/relevance_analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..1771c107f1317ad181838cf408a76d2c955ad30c
--- /dev/null
+++ b/src/ptbench/scripts/relevance_analysis.py
@@ -0,0 +1,89 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Import copy import os import shutil.
+
+import numpy as np
+import torch
+
+from matplotlib.backends.backend_pdf import PdfPages
+from sklearn import metrics
+from torch.utils.data import ConcatDataset, DataLoader
+
+from ..engine.predictor import run
+from ..utils.plot import relevance_analysis_plot
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+# Relevance analysis using permutation feature importance
+if relevance_analysis:
+    if isinstance(v, ConcatDataset) or not isinstance(
+        v._samples[0].data["data"], list
+    ):
+        logger.info(
+            "Relevance analysis only possible with radiological signs as input. Cancelling..."
+        )
+        continue
+
+    nb_features = len(v._samples[0].data["data"])
+
+    if nb_features == 1:
+        logger.info("Relevance analysis not possible with one feature")
+    else:
+        logger.info(f"Starting relevance analysis for subset '{k}'...")
+
+        all_mse = []
+        for f in range(nb_features):
+            v_original = copy.deepcopy(v)
+
+            # Randomly permute feature values from all samples
+            v.random_permute(f)
+
+            data_loader = DataLoader(
+                dataset=v,
+                batch_size=batch_size,
+                shuffle=False,
+                pin_memory=torch.cuda.is_available(),
+            )
+
+            predictions_with_mean = run(
+                model,
+                data_loader,
+                k,
+                accelerator,
+                output_folder + "_temp",
+            )
+
+            # Compute MSE between original and new predictions
+            all_mse.append(
+                metrics.mean_squared_error(
+                    np.array(predictions, dtype=object)[:, 1],
+                    np.array(predictions_with_mean, dtype=object)[:, 1],
+                )
+            )
+
+            # Back to original values
+            v = v_original
+
+        # Remove temporary folder
+        shutil.rmtree(output_folder + "_temp", ignore_errors=True)
+
+        filepath = os.path.join(output_folder, k + "_RA.pdf")
+        logger.info(f"Creating and saving plot at {filepath}...")
+        os.makedirs(os.path.dirname(filepath), exist_ok=True)
+        pdf = PdfPages(filepath)
+        pdf.savefig(
+            relevance_analysis_plot(
+                all_mse,
+                title=k.capitalize() + " set relevance analysis",
+            )
+        )
+        pdf.close()
+"""