diff --git a/src/ptbench/scripts/saliency_interpretability.py b/src/ptbench/scripts/saliency_interpretability.py
index 4d7599a87200075948e759341e3178b328670e8f..bca56565ad098b105af820b1529761849336c95b 100644
--- a/src/ptbench/scripts/saliency_interpretability.py
+++ b/src/ptbench/scripts/saliency_interpretability.py
@@ -52,6 +52,17 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     default="saliency-maps",
     cls=ResourceOption,
 )
+@click.option(
+    "--target-label",
+    "-t",
+    help="""The target label that will be analysed. It must match the target
+    label that was used to generate the saliency maps provided with option
+    ``--input-folder``. Samples with all other labels are ignored.""",
+    required=True,
+    type=click.INT,
+    default=1,
+    cls=ResourceOption,
+)
 @click.option(
     "--output-json",
     "-o",
@@ -70,6 +81,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 def saliency_interpretability(
     datamodule,
     input_folder,
+    target_label,
     output_json,
     **_,
 ) -> None:
@@ -128,7 +140,7 @@ def saliency_interpretability(
     datamodule.prepare_data()
     datamodule.setup(stage="predict")
 
-    results = run(input_folder, datamodule)
+    results = run(input_folder, target_label, datamodule)
 
     with output_json.open("w") as f:
         logger.info(f"Saving output file to `{str(output_json)}`...")
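
For context, a minimal sketch of the filtering behaviour described by the new option's help text ("samples with all other labels are ignored"), assuming samples are (name, metadata) pairs carrying an integer metadata["label"]; this helper and the sample layout are hypothetical and are not part of the patch, which only threads target_label through to run():

    # Hypothetical sketch (not part of this diff): how ``run()`` could use the
    # new ``target_label`` argument to drop samples whose ground-truth label
    # differs from the analysed target label.
    def _keep_target_label(samples, target_label: int) -> list:
        # Keep only samples whose label matches the requested target label.
        return [s for s in samples if int(s[1]["label"]) == target_label]

    # Example: only the two samples labelled ``1`` survive the filter.
    _samples = [("a.png", {"label": 0}), ("b.png", {"label": 1}), ("c.png", {"label": 1})]
    assert len(_keep_target_label(_samples, target_label=1)) == 2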