diff --git a/src/mednet/scripts/saliency/interpretability.py b/src/mednet/scripts/saliency/interpretability.py
index 0d1b84f7692a1c218cca603d7e9306b617a8d1c5..14a99040b87798863d2d97d69b97b46d14bceaae 100644
--- a/src/mednet/scripts/saliency/interpretability.py
+++ b/src/mednet/scripts/saliency/interpretability.py
@@ -23,10 +23,19 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
    .. code:: sh
 
-      mednet saliency interpretability -vv tbx11k-v1-healthy-vs-atb --input-folder=parent-folder/saliencies/ --output-json=path/to/interpretability-scores.json
+      mednet saliency interpretability -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent-folder/saliencies/ --output-json=path/to/interpretability-scores.json
 
 """,
 )
+@click.option(
+    "--model",
+    "-m",
+    help="""A lightning module instance implementing the network architecture
+    (not the weights, necessarily) to be used for inference.  Currently, only
+    supports pasa and densenet models.""",
+    required=True,
+    cls=ResourceOption,
+)
 @click.option(
     "--datamodule",
     "-d",
@@ -78,6 +87,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 )
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def interpretability(
+    model,
     datamodule,
     input_folder,
     target_label,
@@ -114,7 +124,7 @@ def interpretability(
 
     from ...engine.saliency.interpretability import run
 
-    datamodule.model_transforms = []
+    datamodule.model_transforms = model.transforms
     datamodule.prepare_data()
     datamodule.setup(stage="predict")