From 143cf210a3a2c406bbd91809439b98d8431a356c Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 7 Feb 2024 17:10:17 +0100
Subject: [PATCH] [saliency.interpretability] Apply model transforms to
 datamodule

---
 src/mednet/scripts/saliency/interpretability.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/mednet/scripts/saliency/interpretability.py b/src/mednet/scripts/saliency/interpretability.py
index 0d1b84f7..14a99040 100644
--- a/src/mednet/scripts/saliency/interpretability.py
+++ b/src/mednet/scripts/saliency/interpretability.py
@@ -23,10 +23,19 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
    .. code:: sh
 
-      mednet saliency interpretability -vv tbx11k-v1-healthy-vs-atb --input-folder=parent-folder/saliencies/ --output-json=path/to/interpretability-scores.json
+      mednet saliency interpretability -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent-folder/saliencies/ --output-json=path/to/interpretability-scores.json
 
 """,
 )
+@click.option(
+    "--model",
+    "-m",
+    help="""A lightning module instance implementing the network architecture
+    (not the weights, necessarily) to be used for inference.  Currently, only
+    supports pasa and densenet models.""",
+    required=True,
+    cls=ResourceOption,
+)
 @click.option(
     "--datamodule",
     "-d",
@@ -78,6 +87,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 )
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def interpretability(
+    model,
     datamodule,
     input_folder,
     target_label,
@@ -114,7 +124,7 @@ def interpretability(
 
     from ...engine.saliency.interpretability import run
 
-    datamodule.model_transforms = []
+    datamodule.model_transforms = model.transforms
     datamodule.prepare_data()
     datamodule.setup(stage="predict")
 
-- 
GitLab