From 953abf6210019a6536c93055c160924fbb249d32 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Fri, 15 Dec 2023 14:04:58 +0100
Subject: [PATCH] [scripts.saliency_interpretability] Allow user to explicitly
 define target to be analysed

---
 src/ptbench/scripts/saliency_interpretability.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/ptbench/scripts/saliency_interpretability.py b/src/ptbench/scripts/saliency_interpretability.py
index 4d7599a8..bca56565 100644
--- a/src/ptbench/scripts/saliency_interpretability.py
+++ b/src/ptbench/scripts/saliency_interpretability.py
@@ -52,6 +52,17 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     default="saliency-maps",
     cls=ResourceOption,
 )
+@click.option(
+    "--target-label",
+    "-t",
+    help="""The target label that will be analysed.  It must match the target
+    label that was used to generate the saliency maps provided with option
+    ``--input-folder``.  Samples with all other labels are ignored.""",
+    required=True,
+    type=click.INT,
+    default=1,
+    cls=ResourceOption,
+)
 @click.option(
     "--output-json",
     "-o",
@@ -70,6 +81,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 def saliency_interpretability(
     datamodule,
     input_folder,
+    target_label,
     output_json,
     **_,
 ) -> None:
@@ -128,7 +140,7 @@ def saliency_interpretability(
     datamodule.prepare_data()
     datamodule.setup(stage="predict")
 
-    results = run(input_folder, datamodule)
+    results = run(input_folder, target_label, datamodule)
 
     with output_json.open("w") as f:
         logger.info(f"Saving output file to `{str(output_json)}`...")
-- 
GitLab