From 944bba9822207e72fbb39d9ecb663ae97937cb8b Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Thu, 6 Jun 2024 11:06:30 +0200
Subject: [PATCH] [classification.saliency] Fix wrong batch size during
 generation

---
 .../libs/classification/scripts/saliency/completeness.py   | 7 ++-----
 .../libs/classification/scripts/saliency/generate.py       | 7 ++-----
 .../classification/scripts/saliency/interpretability.py    | 1 +
 3 files changed, 5 insertions(+), 10 deletions(-)

diff --git a/src/mednet/libs/classification/scripts/saliency/completeness.py b/src/mednet/libs/classification/scripts/saliency/completeness.py
index fa2d243d..3ee81288 100644
--- a/src/mednet/libs/classification/scripts/saliency/completeness.py
+++ b/src/mednet/libs/classification/scripts/saliency/completeness.py
@@ -204,6 +204,7 @@ def completeness(
     import json
 
     from mednet.libs.common.engine.device import DeviceManager
+    from mednet.libs.common.scripts.predict import setup_datamodule
     from mednet.libs.common.utils.checkpointer import (
         get_checkpoint_to_run_inference,
     )
@@ -222,11 +223,7 @@ def completeness(
     device_manager = DeviceManager(device)
 
     datamodule.cache_samples = cache_samples
-    datamodule.parallel = parallel
-    datamodule.model_transforms = model.model_transforms
-
-    datamodule.prepare_data()
-    datamodule.setup(stage="predict")
+    setup_datamodule(datamodule, model, 1, parallel)
 
     if weight.is_dir():
         weight = get_checkpoint_to_run_inference(weight)
diff --git a/src/mednet/libs/classification/scripts/saliency/generate.py b/src/mednet/libs/classification/scripts/saliency/generate.py
index 3dfd8d84..825c7b3e 100644
--- a/src/mednet/libs/classification/scripts/saliency/generate.py
+++ b/src/mednet/libs/classification/scripts/saliency/generate.py
@@ -170,6 +170,7 @@ def generate(
     """
 
     from mednet.libs.common.engine.device import DeviceManager
+    from mednet.libs.common.scripts.predict import setup_datamodule
     from mednet.libs.common.utils.checkpointer import (
         get_checkpoint_to_run_inference,
     )
@@ -182,12 +183,8 @@ def generate(
     device_manager = DeviceManager(device)
 
     datamodule.cache_samples = cache_samples
-    datamodule.parallel = parallel
     datamodule.model_transforms = model.model_transforms
-    datamodule.batch_size = 1
-
-    datamodule.prepare_data()
-    datamodule.setup(stage="predict")
+    setup_datamodule(datamodule, model, 1, parallel)
 
     if weight.is_dir():
         weight = get_checkpoint_to_run_inference(weight)
diff --git a/src/mednet/libs/classification/scripts/saliency/interpretability.py b/src/mednet/libs/classification/scripts/saliency/interpretability.py
index c4b2ba52..0d0f556b 100644
--- a/src/mednet/libs/classification/scripts/saliency/interpretability.py
+++ b/src/mednet/libs/classification/scripts/saliency/interpretability.py
@@ -123,6 +123,7 @@ def interpretability(
 
     from ...engine.saliency.interpretability import run
 
+    datamodule.batch_size = 1
     datamodule.model_transforms = model.transforms
     datamodule.prepare_data()
     datamodule.setup(stage="predict")
-- 
GitLab