diff --git a/src/mednet/scripts/saliency/interpretability.py b/src/mednet/scripts/saliency/interpretability.py index 0d1b84f7692a1c218cca603d7e9306b617a8d1c5..14a99040b87798863d2d97d69b97b46d14bceaae 100644 --- a/src/mednet/scripts/saliency/interpretability.py +++ b/src/mednet/scripts/saliency/interpretability.py @@ -23,10 +23,19 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") .. code:: sh - mednet saliency interpretability -vv tbx11k-v1-healthy-vs-atb --input-folder=parent-folder/saliencies/ --output-json=path/to/interpretability-scores.json + mednet saliency interpretability -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent-folder/saliencies/ --output-json=path/to/interpretability-scores.json """, ) +@click.option( + "--model", + "-m", + help="""A lightning module instance implementing the network architecture + (not the weights, necessarily) to be used for inference. Currently, only + supports pasa and densenet models.""", + required=True, + cls=ResourceOption, +) @click.option( "--datamodule", "-d", @@ -78,6 +87,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) def interpretability( + model, datamodule, input_folder, target_label, @@ -114,7 +124,7 @@ def interpretability( from ...engine.saliency.interpretability import run - datamodule.model_transforms = [] + datamodule.model_transforms = model.transforms datamodule.prepare_data() datamodule.setup(stage="predict")