From 2e149443f2fcabffc84382234315f698c2343161 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Fri, 6 Oct 2023 21:43:37 +0200 Subject: [PATCH] [saliencymap_generator] Save maps directly --- src/ptbench/engine/saliencymap_generator.py | 79 +++++++++++--------- src/ptbench/models/typing.py | 4 +- src/ptbench/scripts/generate_saliencymaps.py | 26 ++++--- 3 files changed, 61 insertions(+), 48 deletions(-) diff --git a/src/ptbench/engine/saliencymap_generator.py b/src/ptbench/engine/saliencymap_generator.py index 3596cf83..7605a561 100644 --- a/src/ptbench/engine/saliencymap_generator.py +++ b/src/ptbench/engine/saliencymap_generator.py @@ -12,14 +12,14 @@ import torch import torch.nn import tqdm -from ..models.typing import VisualisationType +from ..models.typing import SaliencyMapAlgorithm from .device import DeviceManager logger = logging.getLogger(__name__) -def _create_visualiser( - vis_type: VisualisationType, +def _create_saliency_map_callable( + algo_type: SaliencyMapAlgorithm, model: torch.nn.Module, target_layers: list[torch.nn.Module] | None, use_cuda: bool, @@ -28,7 +28,7 @@ def _create_visualiser( import pytorch_grad_cam - match vis_type: + match algo_type: case "gradcam": return pytorch_grad_cam.GradCAM( model=model, target_layers=target_layers, use_cuda=use_cuda @@ -82,45 +82,45 @@ def _create_visualiser( ) case _: raise ValueError( - f"Visualisation type `{vis_type}` is not currently supported." + f"Saliency map algorithm `{algo_type}` is not currently " + f"supported." ) -def _save_visualisation( - output_folder: pathlib.Path, name: str, vis: torch.Tensor +def _save_saliency_map( + output_folder: pathlib.Path, name: str, saliency_map: torch.Tensor ) -> None: """Helper function to save a saliency map to disk.""" n = pathlib.Path(name) (output_folder / n.parent).mkdir(parents=True, exist_ok=True) - numpy.save(output_folder / n.with_suffix(".npy"), vis) + numpy.save(output_folder / n.with_suffix(".npy"), saliency_map[0]) def run( model: lightning.pytorch.LightningModule, datamodule: lightning.pytorch.LightningDataModule, device_manager: DeviceManager, - visualisation_types: typing.Sequence[VisualisationType], + saliency_map_algorithms: typing.Sequence[SaliencyMapAlgorithm], target_class: typing.Literal["highest", "all"], positive_only: bool, output_folder: pathlib.Path, ) -> None: - """Applies visualisation techniques on input CXR, outputs pickled saliency - maps directly to disk. + """Applies saliency mapping techniques on input CXR, outputs pickled + saliency maps directly to disk. Parameters --------- model Neural network model (e.g. pasa). datamodule - The lightning datamodule to use for training **and** validation + The lightning datamodule to iterate on. device_manager An internal device representation, to be used for training and validation. This representation can be converted into a pytorch device or a torch lightning accelerator setup. - visualisation_types - The types of visualisations to generate for the current model and - datamodule. + saliency_map_algorithms + The algorithms for saliency map estimation to use. target_class (Use only with multi-label models) Which class to target for CAM calculation. Can be either set to "all" or "highest". "highest" is @@ -131,7 +131,7 @@ def run( label == 1 in a binary classification task). This option is ignored on a multi-class output model. output_folder - Where to save all the visualisations (this path should exist before + Where to save all the saliency maps (this path should exist before this function is called) """ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget @@ -140,9 +140,10 @@ def run( from ..models.pasa import Pasa if isinstance(model, Pasa): - if "fullgrad" in visualisation_types: + if "fullgrad" in saliency_map_algorithms: raise ValueError( - "Fullgrad visualisation is not supported for the Pasa model." + "Fullgrad saliency map algorithm is not supported for the " + "Pasa model." ) target_layers = [model.fc14] # Last non-1x1 Conv2d layer elif isinstance(model, Densenet): @@ -159,12 +160,17 @@ def run( model = model.to(device) model.eval() - for vis_type in visualisation_types: - cam = _create_visualiser(vis_type, model, target_layers, use_cuda) # type: ignore + for algo_type in saliency_map_algorithms: + saliency_map_callable = _create_saliency_map_callable( + algo_type, + model, + target_layers, # type: ignore + use_cuda, + ) for k, v in datamodule.predict_dataloader().items(): logger.info( - f"Generating saliency maps for dataset `{k}` via `{vis_type}`..." + f"Generating saliency maps for dataset `{k}` via `{algo_type}`..." ) for sample in tqdm.tqdm( @@ -180,37 +186,42 @@ def run( if positive_only and (model.num_classes == 1) and (label == 0): continue - # chooses target outputs to generate visualisations for + # chooses target outputs to generate saliency maps for if model.num_classes > 1: if target_class == "all": - # just blindly generate visualisations for all outputs + # just blindly generate saliency maps for all outputs # - make one directory for every target output and lay # images there like in the original dataset. for output_num in range(model.num_classes): use_folder = ( - output_folder / vis_type / str(output_num) + output_folder / algo_type / str(output_num) ) - vis = cam( + saliency_map = saliency_map_callable( input_tensor=image, targets=[ClassifierOutputTarget(output_num)], # type: ignore ) - _save_visualisation(use_folder, name, vis) # type: ignore + _save_saliency_map(use_folder, name, saliency_map) # type: ignore else: # pytorch-grad-cam will evaluate the output with the - # highest value and produce a visualisation for it - we will - # save it to disk. - use_folder = output_folder / vis_type / "highest-output" - vis = cam(input_tensor=image, targets=None) # type: ignore - _save_visualisation(use_folder, name, vis) # type: ignore + # highest value and produce a saliency map for it - we + # will save it to disk. + use_folder = ( + output_folder / algo_type / "highest-output" + ) + saliency_map = saliency_map_callable( + input_tensor=image, + targets=None, # type: ignore + ) + _save_saliency_map(use_folder, name, saliency_map) # type: ignore else: # binary classification model with a single output - just # lay all cams uniformily like the original dataset - use_folder = output_folder / vis_type - vis = cam( + use_folder = output_folder / algo_type + saliency_map = saliency_map_callable( input_tensor=image, targets=[ ClassifierOutputTarget(0), # type: ignore ], ) - _save_visualisation(use_folder, name, vis) # type: ignore + _save_saliency_map(use_folder, name, saliency_map) # type: ignore diff --git a/src/ptbench/models/typing.py b/src/ptbench/models/typing.py index 3c64a873..883811ac 100644 --- a/src/ptbench/models/typing.py +++ b/src/ptbench/models/typing.py @@ -26,7 +26,7 @@ MultiClassPredictionSplit: typing.TypeAlias = typing.Mapping[ ] """A series of predictions for different database splits.""" -VisualisationType: typing.TypeAlias = typing.Literal[ +SaliencyMapAlgorithm: typing.TypeAlias = typing.Literal[ "ablationcam", "eigencam", "eigengradcam", @@ -41,4 +41,4 @@ VisualisationType: typing.TypeAlias = typing.Literal[ "scorecam", "xgradcam", ] -"""Supported visualisation types.""" +"""Supported saliency map algorithms.""" diff --git a/src/ptbench/scripts/generate_saliencymaps.py b/src/ptbench/scripts/generate_saliencymaps.py index 7ba30ede..c330f5de 100644 --- a/src/ptbench/scripts/generate_saliencymaps.py +++ b/src/ptbench/scripts/generate_saliencymaps.py @@ -10,7 +10,7 @@ import click from clapper.click import ResourceOption, verbosity_option from clapper.logging import setup -from ..models.typing import VisualisationType +from ..models.typing import SaliencyMapAlgorithm from .click import ConfigCommand logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -27,7 +27,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") .. code:: sh - ptbench generate-saliencymaps -vv densenet tbx11k-v1-healthy-vs-atb --weight=path/to/model_final.pth --output-folder=path/to/visualisations + ptbench generate-saliencymaps -vv densenet tbx11k-v1-healthy-vs-atb --weight=path/to/model_final.pth --output-folder=path/to/output """, ) @@ -53,7 +53,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @click.option( "--output-folder", "-o", - help="Path where to store the visualisations (created if does not exist)", + help="Path where to store saliency maps (created if does not exist)", required=True, type=click.Path( exists=False, @@ -62,7 +62,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") writable=True, path_type=pathlib.Path, ), - default="visualisations", + default="saliency-maps", cls=ResourceOption, ) @click.option( @@ -106,11 +106,13 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") cls=ResourceOption, ) @click.option( - "--visualisation-types", - "-vt", - help="""visualisation techniques to be used. Can be called multiple times + "--saliency-map-algorithm", + "-s", + help="""Saliency map algorithm(s) to be used. Can be called multiple times with different techniques.""", - type=click.Choice(typing.get_args(VisualisationType), case_sensitive=False), + type=click.Choice( + typing.get_args(SaliencyMapAlgorithm), case_sensitive=False + ), multiple=True, default=["gradcam"], show_default=True, @@ -149,7 +151,7 @@ def generate_saliencymaps( cache_samples, weight, parallel, - visualisation_types, + saliency_map_algorithm, target_class, positive_only, **_, @@ -157,8 +159,8 @@ def generate_saliencymaps( """Generates saliency maps for locations on input images that affected the prediction. - The quality of saliency information depends on visualisation - technique and trained model. + The quality of saliency information depends on the saliency map + algorithm and trained model. """ from ..engine.device import DeviceManager @@ -189,7 +191,7 @@ def generate_saliencymaps( model=model, datamodule=datamodule, device_manager=device_manager, - visualisation_types=visualisation_types, + saliency_map_algorithms=saliency_map_algorithm, target_class=target_class, positive_only=positive_only, output_folder=output_folder, -- GitLab