# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later import pathlib import typing import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup from ...models.typing import SaliencyMapAlgorithm from ..click import ConfigCommand logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.command( entry_point_group="ptbench.config", cls=ConfigCommand, epilog="""Examples: 1. Calculates the ROAD scores for an existing dataset configuration and stores them in .csv files: .. code:: sh ptbench saliency completeness -vv pasa tbx11k-v1-healthy-vs-atb --device="cuda" --weight=path/to/model-at-lowest-validation-loss.ckpt --output-folder=path/to/completeness-scores/ """, ) @click.option( "--model", "-m", help="""A lightining module instance implementing the network architecture (not the weights, necessarily) to be used for inference. Currently, only supports pasa and densenet models.""", required=True, cls=ResourceOption, ) @click.option( "--datamodule", "-d", help="""A lighting data module that will be asked for prediction data loaders. Typically, this includes all configured splits in a datamodule, however this is not a requirement. A datamodule that returns a single dataloader for prediction (wrapped in a dictionary) is acceptable.""", required=True, cls=ResourceOption, ) @click.option( "--output-folder", "-o", help="Path where to store saliency maps (created if does not exist)", required=True, type=click.Path( exists=False, file_okay=False, dir_okay=True, writable=True, path_type=pathlib.Path, ), default="saliency-maps", cls=ResourceOption, ) @click.option( "--device", "-x", help='A string indicating the device to use (e.g. "cpu" or "cuda:0")', show_default=True, required=True, default="cpu", cls=ResourceOption, ) @click.option( "--cache-samples/--no-cache-samples", help="If set to True, loads the sample into memory, " "otherwise loads them at runtime.", required=True, show_default=True, default=False, cls=ResourceOption, ) @click.option( "--weight", "-w", help="""Path or URL to pretrained model file (`.ckpt` extension), corresponding to the architecture set with `--model`.""", required=True, cls=ResourceOption, type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True), ) @click.option( "--parallel", "-P", help="""Use multiprocessing for data loading processing: if set to -1 (default), disables multiprocessing. Set to 0 to enable as many data processing instances as processing cores available in the system. Set to >= 1 to enable that many multiprocessing instances. Note that if you activate this option, then you must use --device=cpu, as using a GPU concurrently is not supported.""", type=click.IntRange(min=-1), show_default=True, required=True, default=-1, cls=ResourceOption, ) @click.option( "--saliency-map-algorithm", "-s", help="""Saliency map algorithm(s) to be used. Can be called multiple times with different techniques.""", type=click.Choice( typing.get_args(SaliencyMapAlgorithm), case_sensitive=False ), multiple=True, default=["gradcam"], show_default=True, cls=ResourceOption, ) @click.option( "--target-class", "-C", help="""This option should only be used with multiclass models. It defines the class to target for saliency estimation. Can be either set to "all" or "highest". "highest" (the default), means only saliency maps for the class with the highest activation will be generated.""", required=False, type=click.Choice( ["highest", "all"], case_sensitive=False, ), default="highest", cls=ResourceOption, ) @click.option( "--positive-only/--no-positive-only", "-z/-Z", help="""If set, and the model chosen has a single output (binary), then saliency maps will only be generated for samples of the positive class. This option has no effect for multiclass models.""", default=False, cls=ResourceOption, ) @click.option( "--percentile", "-e", help="""One or more percentiles (percent x100) integer values indicating the proportion of pixels to perturb in the original image to calculate both MoRF and LeRF scores.""", multiple=True, default=[20, 40, 60, 80], show_default=True, cls=ResourceOption, ) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) def completeness( model, datamodule, output_folder, device, cache_samples, weight, parallel, saliency_map_algorithm, target_class, positive_only, percentile, **_, ) -> None: """Evaluates saliency map algorithm completeness using RemOve And Debias (ROAD). For each selected saliency map algorithm, evaluates the completeness of explanations using the RemOve And Debias (ROAD) algorithm. The ROAD algorithm was first described at [ROAD-2022]_. It estimates explainability (in the completeness sense) of saliency mapping algorithms by substituting relevant pixels in the input image by a local average, re-running prediction on the altered image, and measuring changes in the output classification score when said perturbations are in place. By substituting most or least relevant pixels with surrounding averages, the ROAD algorithm estimates the importance of such elements in the produced saliency map. As 2023, this measurement technique is considered to be one of the state-of-the-art metrics of explainability. This program outputs a JSON file containing the ROAD evaluations (using most-relevant-first, or MoRF, and least-relevant-first, or LeRF for each sample in the datamodule, and per saliency-mapping algorithm. Each saliency-mapping algorithm yields a single JSON file with the target algorithm name on the ``output-folder``. Values for MoRF and LeRF represent averages by removing 20, 40, 60 and 80% of most or least relevant pixels respectively from the image, and averaging results for all these percentiles. .. note:: This application is relatively slow when processing a large datamodule with many (positive) samples. """ import json from ...engine.device import DeviceManager from ...engine.saliency.completeness import run logger.info(f"Output folder: {output_folder}") output_folder.mkdir(parents=True, exist_ok=True) if device in ("cuda", "mps") and (parallel == 0 or parallel > 1): raise RuntimeError( f"The number of multiprocessing instances is set to {parallel} and " f"you asked to use a GPU (device = `{device}`). The currently " f"implementation can only handle a single GPU. Either disable GPU " f"utilisation or set the number of multiprocessing instances to " f"one, or disable multiprocessing entirely (ie. set it to -1)." ) device_manager = DeviceManager(device) # batch_size must be == 1 for now (underlying code is NOT prepared to # treat multiple samples at once). datamodule.set_chunk_size(1, 1) datamodule.cache_samples = cache_samples datamodule.parallel = parallel datamodule.model_transforms = model.model_transforms datamodule.prepare_data() datamodule.setup(stage="predict") logger.info(f"Loading checkpoint from `{weight}`...") model = model.load_from_checkpoint(weight, strict=False) for algo in saliency_map_algorithm: logger.info( f"Evaluating RemOve And Debias (ROAD) average scores for " f"algorithm `{algo}` with percentiles " f"`{', '.join([str(k) for k in percentile])}`..." ) results = run( model=model, datamodule=datamodule, device_manager=device_manager, saliency_map_algorithm=algo, target_class=target_class, positive_only=positive_only, percentiles=percentile, parallel=parallel, ) output_json = output_folder / (algo + ".json") with output_json.open("w") as f: logger.info(f"Saving output file to `{str(output_json)}`...") json.dump(results, f, indent=2)