From 143cf210a3a2c406bbd91809439b98d8431a356c Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 7 Feb 2024 17:10:17 +0100 Subject: [PATCH] [saliency.interpretability] Apply model transforms to datamodule --- src/mednet/scripts/saliency/interpretability.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/mednet/scripts/saliency/interpretability.py b/src/mednet/scripts/saliency/interpretability.py index 0d1b84f7..14a99040 100644 --- a/src/mednet/scripts/saliency/interpretability.py +++ b/src/mednet/scripts/saliency/interpretability.py @@ -23,10 +23,19 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") .. code:: sh - mednet saliency interpretability -vv tbx11k-v1-healthy-vs-atb --input-folder=parent-folder/saliencies/ --output-json=path/to/interpretability-scores.json + mednet saliency interpretability -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent-folder/saliencies/ --output-json=path/to/interpretability-scores.json """, ) +@click.option( + "--model", + "-m", + help="""A lightning module instance implementing the network architecture + (not the weights, necessarily) to be used for inference. Currently, only + supports pasa and densenet models.""", + required=True, + cls=ResourceOption, +) @click.option( "--datamodule", "-d", @@ -78,6 +87,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) def interpretability( + model, datamodule, input_folder, target_label, @@ -114,7 +124,7 @@ def interpretability( from ...engine.saliency.interpretability import run - datamodule.model_transforms = [] + datamodule.model_transforms = model.transforms datamodule.prepare_data() datamodule.setup(stage="predict") -- GitLab