# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later import functools import logging import multiprocessing import typing import lightning.pytorch import numpy as np import torch import tqdm from pytorch_grad_cam.metrics.road import ( ROADLeastRelevantFirstAverage, ROADMostRelevantFirstAverage, ) from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget from ...data.typing import Sample from ...models.typing import SaliencyMapAlgorithm from ..device import DeviceManager logger = logging.getLogger(__name__) class SigmoidClassifierOutputTarget(torch.nn.Module): def __init__(self, category): self.category = category def __call__(self, model_output): sigmoid_output = torch.sigmoid(model_output) if len(sigmoid_output.shape) == 1: return sigmoid_output[self.category] return sigmoid_output[:, self.category] def _calculate_road_scores( model: lightning.pytorch.LightningModule, images: torch.Tensor, output_num: int, saliency_map_callable: typing.Callable, percentiles: typing.Sequence[int], ) -> tuple[float, float, float]: """Calculates average ROAD scores for different removal percentiles. This function calculates ROAD scores by averaging the scores for different removal (hardcoded) percentiles, for a single input image, a given visualization method, a target class. Parameters ---------- model Neural network model (e.g. pasa). images A batch of input images to use evaluating the ROAD scores. Currently, we only support batches with a single image. output_num Target output neuron to take into consideration when evaluating the saliency maps and calculating ROAD scores saliency_map_callable A callable saliency-map generator from grad-cam percentiles A sequence of percentiles (percent x100) integer values indicating the proportion of pixels to perturb in the original image to calculate both MoRF and LeRF scores. Returns ------- A 3-tuple containing floating point numbers representing the most-relevant-first average score (``morf``), least-relevant-first average score (``lerf``) and the combined value (``(lerf-morf)/2``). """ saliency_map = saliency_map_callable( input_tensor=images, targets=[ClassifierOutputTarget(output_num)] ) cam_metric_ROADMoRF_avg = ROADMostRelevantFirstAverage( percentiles=percentiles ) cam_metric_ROADLeRF_avg = ROADLeastRelevantFirstAverage( percentiles=percentiles ) # Calculate ROAD scores for all percentiles and average - this is NOT the # current processing bottleneck. If you want to optimise anyting, look at # the evaluation of the perturbation using scipy.sparse at the # NoisyLinearImputer, part of the grad-cam package (submodule # ``metrics.road``. metric_target = [SigmoidClassifierOutputTarget(output_num)] MoRF_scores = cam_metric_ROADMoRF_avg( input_tensor=images, cams=saliency_map, model=model, targets=metric_target, ) LeRF_scores = cam_metric_ROADLeRF_avg( input_tensor=images, cams=saliency_map, model=model, targets=metric_target, ) return ( float(MoRF_scores.item()), float(LeRF_scores.item()), float(LeRF_scores.item() - MoRF_scores.item()) / 2.0, ) def _process_sample( sample: Sample, model: lightning.pytorch.LightningModule, device: torch.device, saliency_map_callable: typing.Callable, target_class: typing.Literal["highest", "all"], positive_only: bool, percentiles: typing.Sequence[int], ) -> list: """Helper function to :py:func:`run` to be used in multiprocessing contexts. Parameters ---------- model Neural network model (e.g. pasa). device The device to process samples on. saliency_map_callable A callable saliency-map generator from grad-cam target_class Class to target for saliency estimation. Can be either set to "all" or "highest". "highest". positive only If set, and the model chosen has a single output (binary), then saliency maps will only be generated for samples of the positive class percentiles A sequence of percentiles (percent x100) integer values indicating the proportion of pixels to perturb in the original image to calculate both MoRF and LeRF scores. """ name: str = sample[1]["name"][0] label: int = int(sample[1]["label"].item()) image = sample[0].to(device=device, non_blocking=torch.cuda.is_available()) # in binary classification systems, negative labels may be skipped if positive_only and (model.num_classes == 1) and (label == 0): return [name, label] # chooses target outputs to generate saliency maps for if model.num_classes > 1: # type: ignore if target_class == "all": # test all outputs for output_num in range(model.num_classes): # type: ignore results = _calculate_road_scores( model, image, output_num, saliency_map_callable, percentiles, ) return [name, label, output_num, *results] else: # we will figure out the output with the highest value and # evaluate the saliency mapping technique over it. outputs = saliency_map_callable.activations_and_grads(image) # type: ignore output_nums = np.argmax(outputs.cpu().data.numpy(), axis=-1) assert len(output_nums) == 1 results = _calculate_road_scores( model, image, output_nums[0], saliency_map_callable, percentiles, ) return [name, label, output_nums[0], *results] # default route for binary classification results = _calculate_road_scores( model, image, 0, saliency_map_callable, percentiles, ) return [name, label, 0, *results] def run( model: lightning.pytorch.LightningModule, datamodule: lightning.pytorch.LightningDataModule, device_manager: DeviceManager, saliency_map_algorithm: SaliencyMapAlgorithm, target_class: typing.Literal["highest", "all"], positive_only: bool, percentiles: typing.Sequence[int], parallel: int, ) -> dict[str, list[typing.Any]]: """Evaluates ROAD scores for all samples in a datamodule. The ROAD algorithm was first described at [ROAD-2022]_. It estimates explainability (in the completeness sense) of saliency maps by substituting relevant pixels in the input image by a local average, and re-running prediction on the altered image, and measuring changes in the output classification score when said perturbations are in place. By substituting most or least relevant pixels with surrounding averages, the ROAD algorithm estimates the importance of such elements in the produced saliency map. As 2023, this measurement technique is considered to be one of the state-of-the-art metrics of explainability. This function returns a dictionary containing most-relevant-first (remove a percentile of the most relevant pixels), least-relevant-first (remove a percentile of the least relevant pixels), and combined ROAD evaluations per sample for a particular saliency mapping algorithm. Parameters --------- model Neural network model (e.g. pasa). datamodule The lightning datamodule to iterate on. device_manager An internal device representation, to be used for training and validation. This representation can be converted into a pytorch device or a torch lightning accelerator setup. saliency_map_algorithm The algorithm for saliency map estimation to use. target_class (Use only with multi-label models) Which class to target for CAM calculation. Can be either set to "all" or "highest". "highest" is default, which means only saliency maps for the class with the highest activation will be generated. positive_only If set, saliency maps will only be generated for positive samples (ie. label == 1 in a binary classification task). This option is ignored on a multi-class output model. percentiles A sequence of percentiles (percent x100) integer values indicating the proportion of pixels to perturb in the original image to calculate both MoRF and LeRF scores. parallel Use multiprocessing for data processing: if set to -1, disables multiprocessing. Set to 0 to enable as many data processing instances as processing cores as available in the system. Set to >= 1 to enable that many multiprocessing instances for data processing. Returns ------- A dictionary where keys are dataset names in the provide datamodule, and values are lists containing sample information alongside metrics calculated: * Sample name * Sample target class * The model output number used for the ROAD analysis (0, for binary classifers as there is typically only one output). * ``morf``: ROAD most-relevant-first average of percentiles 20, 40, 60 and 80 (a.k.a. AOPC-MoRF). * ``lerf``: ROAD least-relevant-first average of percentiles 20, 40, 60 and 80 (a.k.a. AOPC-LeRF). * combined: Average ROAD combined score by evaluating ``(lerf-morf)/2`` (a.k.a. AOPC-Combined). """ from ...models.densenet import Densenet from ...models.pasa import Pasa from .generator import _create_saliency_map_callable if isinstance(model, Pasa): if saliency_map_algorithm == "fullgrad": raise ValueError( "Fullgrad saliency map algorithm is not supported for the " "Pasa model." ) target_layers = [model.fc14] # Last non-1x1 Conv2d layer elif isinstance(model, Densenet): target_layers = [ model.model_ft.features.denseblock4.denselayer16.conv2, # type: ignore ] else: raise TypeError(f"Model of type `{type(model)}` is not yet supported.") use_cuda = device_manager.device_type == "cuda" if device_manager.device_type in ("cuda", "mps") and ( parallel == 0 or parallel > 1 ): raise RuntimeError( f"The number of multiprocessing instances is set to {parallel} and " f"you asked to use a GPU (device = `{device_manager.device_type}`" f"). The currently implementation can only handle a single GPU. " f"Either disable GPU utilisation or set the number of " f"multiprocessing instances to one, or disable multiprocessing " "entirely (ie. set it to -1)." ) # prepares model for evaluation, cast to target device device = device_manager.torch_device() model = model.to(device) model.eval() saliency_map_callable = _create_saliency_map_callable( saliency_map_algorithm, model, target_layers, # type: ignore use_cuda, ) retval: dict[str, list[typing.Any]] = {} # our worker function _process = functools.partial( _process_sample, model=model, device=device, saliency_map_callable=saliency_map_callable, target_class=target_class, positive_only=positive_only, percentiles=percentiles, ) for k, v in datamodule.predict_dataloader().items(): retval[k] = [] if parallel < 0: logger.info( f"Computing ROAD scores for dataset `{k}` in the current " f"process context..." ) for sample in tqdm.tqdm( v, desc="samples", leave=False, disable=None ): retval[k].append(_process(sample)) else: instances = parallel or multiprocessing.cpu_count() logger.info( f"Computing ROAD scores for dataset `{k}` using {instances} " f"processes..." ) with multiprocessing.Pool(instances) as p: retval[k] = list(tqdm.tqdm(p.imap(_process, v), total=len(v))) return retval