# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import pathlib
import typing

import click

from clapper.click import ResourceOption, verbosity_option
from clapper.logging import setup

from ..models.typing import SaliencyMapAlgorithm
from .click import ConfigCommand

# Root-package logger (clapper's setup() wires up formatting/verbosity).
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")


@click.command(
    entry_point_group="ptbench.config",
    cls=ConfigCommand,
    epilog="""Examples:

1. Generates saliency maps for all prediction dataloaders on a datamodule,
   using a pre-trained DenseNet model, and saves them as numpy-pickled
   objects on the output directory:

   .. code:: sh

      ptbench generate-saliencymaps -vv densenet tbx11k-v1-healthy-vs-atb --weight=path/to/model_final.ckpt --output-folder=path/to/output

""",
)
@click.option(
    "--model",
    "-m",
    help="""A lightning module instance implementing the network architecture
    (not the weights, necessarily) to be used for inference.  Currently, only
    supports pasa and densenet models.""",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--datamodule",
    "-d",
    help="""A lightning data module that will be asked for prediction data
    loaders.  Typically, this includes all configured splits in a datamodule,
    however this is not a requirement.  A datamodule that returns a single
    dataloader for prediction (wrapped in a dictionary) is acceptable.""",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store saliency maps (created if does not exist)",
    required=True,
    type=click.Path(
        exists=False,
        file_okay=False,
        dir_okay=True,
        writable=True,
        path_type=pathlib.Path,
    ),
    default="saliency-maps",
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-x",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--cache-samples/--no-cache-samples",
    help="If set to True, loads the sample into memory, "
    "otherwise loads them at runtime.",
    required=True,
    show_default=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--weight",
    "-w",
    help="""Path or URL to pretrained model file (`.ckpt` extension),
    corresponding to the architecture set with `--model`.""",
    required=True,
    cls=ResourceOption,
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True),
)
@click.option(
    "--parallel",
    "-P",
    help="""Use multiprocessing for data loading: if set to -1 (default),
    disables multiprocessing data loading.  Set to 0 to enable as many data
    loading instances as processing cores as available in the system.  Set to
    >= 1 to enable that many multiprocessing instances for data loading.""",
    type=click.IntRange(min=-1),
    show_default=True,
    required=True,
    default=-1,
    cls=ResourceOption,
)
@click.option(
    "--saliency-map-algorithm",
    "-s",
    help="""Saliency map algorithm(s) to be used.  Can be called multiple times
    with different techniques.""",
    type=click.Choice(
        typing.get_args(SaliencyMapAlgorithm), case_sensitive=False
    ),
    multiple=True,
    # immutable tuple default (convention for multiple=True options; also
    # avoids a shared mutable default)
    default=("gradcam",),
    show_default=True,
    cls=ResourceOption,
)
@click.option(
    "--target-class",
    "-C",
    help="""This option should only be used with multiclass models.  It
    defines the class to target for saliency estimation. Can be either set to
    "all" or "highest". "highest" (the default), means only saliency maps for
    the class with the highest activation will be generated.""",
    required=False,
    type=click.Choice(
        ["highest", "all"],
        case_sensitive=False,
    ),
    default="highest",
    cls=ResourceOption,
)
@click.option(
    "--positive-only/--no-positive-only",
    "-z/-Z",
    help="""If set, and the model chosen has a single output (binary), then
    saliency maps will only be generated for samples of the positive class.
    This option has no effect for multiclass models.""",
    default=False,
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def generate_saliencymaps(
    model,
    datamodule,
    output_folder,
    device,
    cache_samples,
    weight,
    parallel,
    saliency_map_algorithm,
    target_class,
    positive_only,
    **_,
) -> None:
    """Generates saliency maps for locations on input images that affected the
    prediction.

    The quality of saliency information depends on the saliency map
    algorithm and trained model.
    """

    # Deferred imports keep CLI start-up fast and avoid pulling in the
    # inference engine unless this subcommand actually runs.
    from ..engine.device import DeviceManager
    from ..engine.saliency.generator import run

    logger.info(f"Output folder: {output_folder}")
    output_folder.mkdir(parents=True, exist_ok=True)

    device_manager = DeviceManager(device)

    # batch_size must be == 1 for now (underlying code is NOT prepared to
    # treat multiple samples at once).
    datamodule.set_chunk_size(1, 1)
    datamodule.cache_samples = cache_samples
    datamodule.parallel = parallel
    # Ensure the datamodule applies the same input transforms the model was
    # trained with before producing prediction batches.
    datamodule.model_transforms = model.model_transforms

    datamodule.prepare_data()
    datamodule.setup(stage="predict")

    logger.info(f"Loading checkpoint from `{weight}`...")
    # strict=False: allow the checkpoint to omit/add auxiliary keys while
    # still loading the architecture weights set via --model.
    model = model.load_from_checkpoint(weight, strict=False)

    run(
        model=model,
        datamodule=datamodule,
        device_manager=device_manager,
        saliency_map_algorithms=saliency_map_algorithm,
        target_class=target_class,
        positive_only=positive_only,
        output_folder=output_folder,
    )