diff --git a/src/mednet/libs/classification/scripts/saliency/completeness.py b/src/mednet/libs/classification/scripts/saliency/completeness.py
index fa2d243d6287df127f6da33cb0a31f7ed2b405f6..3ee81288101676fff4714adcfb08d1f41c20c378 100644
--- a/src/mednet/libs/classification/scripts/saliency/completeness.py
+++ b/src/mednet/libs/classification/scripts/saliency/completeness.py
@@ -204,6 +204,7 @@ def completeness(
     import json
 
     from mednet.libs.common.engine.device import DeviceManager
+    from mednet.libs.common.scripts.predict import setup_datamodule
     from mednet.libs.common.utils.checkpointer import (
         get_checkpoint_to_run_inference,
     )
@@ -222,11 +223,7 @@ def completeness(
     device_manager = DeviceManager(device)
 
     datamodule.cache_samples = cache_samples
-    datamodule.parallel = parallel
-    datamodule.model_transforms = model.model_transforms
-
-    datamodule.prepare_data()
-    datamodule.setup(stage="predict")
+    setup_datamodule(datamodule, model, 1, parallel)
 
     if weight.is_dir():
         weight = get_checkpoint_to_run_inference(weight)
diff --git a/src/mednet/libs/classification/scripts/saliency/generate.py b/src/mednet/libs/classification/scripts/saliency/generate.py
index 3dfd8d84d29c3142f9081e4d6ffe56a38ffa9726..825c7b3e2d22cf8b487143606420e1c050c71454 100644
--- a/src/mednet/libs/classification/scripts/saliency/generate.py
+++ b/src/mednet/libs/classification/scripts/saliency/generate.py
@@ -170,6 +170,7 @@ def generate(
     """
 
     from mednet.libs.common.engine.device import DeviceManager
+    from mednet.libs.common.scripts.predict import setup_datamodule
     from mednet.libs.common.utils.checkpointer import (
         get_checkpoint_to_run_inference,
     )
@@ -182,12 +183,8 @@ def generate(
     device_manager = DeviceManager(device)
 
     datamodule.cache_samples = cache_samples
-    datamodule.parallel = parallel
     datamodule.model_transforms = model.model_transforms
-    datamodule.batch_size = 1
-
-    datamodule.prepare_data()
-    datamodule.setup(stage="predict")
+    setup_datamodule(datamodule, model, 1, parallel)
 
     if weight.is_dir():
         weight = get_checkpoint_to_run_inference(weight)
diff --git a/src/mednet/libs/classification/scripts/saliency/interpretability.py b/src/mednet/libs/classification/scripts/saliency/interpretability.py
index c4b2ba525ed720ac1c8bc4fc2651cc2c7e6fb3d4..0d0f556b49e6b6dd8f0a4c9d68adbfa9d5003192 100644
--- a/src/mednet/libs/classification/scripts/saliency/interpretability.py
+++ b/src/mednet/libs/classification/scripts/saliency/interpretability.py
@@ -123,6 +123,7 @@ def interpretability(
 
     from ...engine.saliency.interpretability import run
 
+    datamodule.batch_size = 1
     datamodule.model_transforms = model.transforms
     datamodule.prepare_data()
     datamodule.setup(stage="predict")