From 2e149443f2fcabffc84382234315f698c2343161 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Fri, 6 Oct 2023 21:43:37 +0200
Subject: [PATCH] [saliencymap_generator] Save maps directly
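
Renames the "visualisation" nomenclature to "saliency map" across the
engine, typing declarations and CLI: ``VisualisationType`` becomes
``SaliencyMapAlgorithm`` and the option ``--visualisation-types/-vt``
becomes ``--saliency-map-algorithm/-s`` (default output folder:
``saliency-maps``).  The generator now saves only the first map of the
batch returned by pytorch-grad-cam, so each ``.npy`` file contains a
single 2D map.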

---
 src/ptbench/engine/saliencymap_generator.py  | 79 +++++++++++---------
 src/ptbench/models/typing.py                 |  4 +-
 src/ptbench/scripts/generate_saliencymaps.py | 26 ++++---
 3 files changed, 61 insertions(+), 48 deletions(-)
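
The saved arrays can be reloaded with numpy for downstream analysis.  A
minimal sketch, assuming the single-output (binary) model layout and a
hypothetical sample path:

    import numpy

    # each file holds a single 2D map (height x width); pytorch-grad-cam
    # returns one map per batch sample and only the first one is saved
    saliency_map = numpy.load("saliency-maps/gradcam/sample-001.npy")
    print(saliency_map.shape, saliency_map.min(), saliency_map.max())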

diff --git a/src/ptbench/engine/saliencymap_generator.py b/src/ptbench/engine/saliencymap_generator.py
index 3596cf83..7605a561 100644
--- a/src/ptbench/engine/saliencymap_generator.py
+++ b/src/ptbench/engine/saliencymap_generator.py
@@ -12,14 +12,14 @@ import torch
 import torch.nn
 import tqdm
 
-from ..models.typing import VisualisationType
+from ..models.typing import SaliencyMapAlgorithm
 from .device import DeviceManager
 
 logger = logging.getLogger(__name__)
 
 
-def _create_visualiser(
-    vis_type: VisualisationType,
+def _create_saliency_map_callable(
+    algo_type: SaliencyMapAlgorithm,
     model: torch.nn.Module,
     target_layers: list[torch.nn.Module] | None,
     use_cuda: bool,
@@ -28,7 +28,7 @@ def _create_visualiser(
 
     import pytorch_grad_cam
 
-    match vis_type:
+    match algo_type:
         case "gradcam":
             return pytorch_grad_cam.GradCAM(
                 model=model, target_layers=target_layers, use_cuda=use_cuda
@@ -82,45 +82,45 @@ def _create_visualiser(
             )
         case _:
             raise ValueError(
-                f"Visualisation type `{vis_type}` is not currently supported."
+                f"Saliency map algorithm `{algo_type}` is not currently "
+                f"supported."
             )
 
 
-def _save_visualisation(
-    output_folder: pathlib.Path, name: str, vis: torch.Tensor
+def _save_saliency_map(
+    output_folder: pathlib.Path, name: str, saliency_map: torch.Tensor
 ) -> None:
     """Helper function to save a saliency map to disk."""
 
     n = pathlib.Path(name)
     (output_folder / n.parent).mkdir(parents=True, exist_ok=True)
-    numpy.save(output_folder / n.with_suffix(".npy"), vis)
+    numpy.save(output_folder / n.with_suffix(".npy"), saliency_map[0])
 
 
 def run(
     model: lightning.pytorch.LightningModule,
     datamodule: lightning.pytorch.LightningDataModule,
     device_manager: DeviceManager,
-    visualisation_types: typing.Sequence[VisualisationType],
+    saliency_map_algorithms: typing.Sequence[SaliencyMapAlgorithm],
     target_class: typing.Literal["highest", "all"],
     positive_only: bool,
     output_folder: pathlib.Path,
 ) -> None:
-    """Applies visualisation techniques on input CXR, outputs pickled saliency
-    maps directly to disk.
+    """Applies saliency mapping techniques on input CXR, outputs pickled
+    saliency maps directly to disk.
 
     Parameters
     ----------
     model
         Neural network model (e.g. pasa).
     datamodule
-        The lightning datamodule to use for training **and** validation
+        The lightning datamodule to iterate on.
     device_manager
         An internal device representation, to be used for training and
         validation.  This representation can be converted into a pytorch device
         or a torch lightning accelerator setup.
-    visualisation_types
-        The types of visualisations to generate for the current model and
-        datamodule.
+    saliency_map_algorithms
+        The saliency map estimation algorithms to use.
     target_class
         (Use only with multi-label models) Which class to target for CAM
         calculation. Can be either set to "all" or "highest". "highest" is
@@ -131,7 +131,7 @@ def run(
         label == 1 in a binary classification task).  This option is ignored on
         a multi-class output model.
     output_folder
-        Where to save all the visualisations (this path should exist before
+        Where to save all the saliency maps (this path should exist before
         this function is called).
     """
     from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
@@ -140,9 +140,10 @@ def run(
     from ..models.pasa import Pasa
 
     if isinstance(model, Pasa):
-        if "fullgrad" in visualisation_types:
+        if "fullgrad" in saliency_map_algorithms:
             raise ValueError(
-                "Fullgrad visualisation is not supported for the Pasa model."
+                "Fullgrad saliency map algorithm is not supported for the "
+                "Pasa model."
             )
         target_layers = [model.fc14]  # Last non-1x1 Conv2d layer
     elif isinstance(model, Densenet):
@@ -159,12 +160,17 @@ def run(
     model = model.to(device)
     model.eval()
 
-    for vis_type in visualisation_types:
-        cam = _create_visualiser(vis_type, model, target_layers, use_cuda)  # type: ignore
+    for algo_type in saliency_map_algorithms:
+        saliency_map_callable = _create_saliency_map_callable(
+            algo_type,
+            model,
+            target_layers,  # type: ignore
+            use_cuda,
+        )
 
         for k, v in datamodule.predict_dataloader().items():
             logger.info(
-                f"Generating saliency maps for dataset `{k}` via `{vis_type}`..."
+                f"Generating saliency maps for dataset `{k}` via `{algo_type}`..."
             )
 
             for sample in tqdm.tqdm(
@@ -180,37 +186,42 @@ def run(
                 if positive_only and (model.num_classes == 1) and (label == 0):
                     continue
 
-                # chooses target outputs to generate visualisations for
+                # chooses target outputs to generate saliency maps for
                 if model.num_classes > 1:
                     if target_class == "all":
-                        # just blindly generate visualisations for all outputs
+                        # just blindly generate saliency maps for all outputs
                         # - make one directory for every target output and lay
                         # images there like in the original dataset.
                         for output_num in range(model.num_classes):
                             use_folder = (
-                                output_folder / vis_type / str(output_num)
+                                output_folder / algo_type / str(output_num)
                             )
-                            vis = cam(
+                            saliency_map = saliency_map_callable(
                                 input_tensor=image,
                                 targets=[ClassifierOutputTarget(output_num)],  # type: ignore
                             )
-                            _save_visualisation(use_folder, name, vis)  # type: ignore
+                            _save_saliency_map(use_folder, name, saliency_map)  # type: ignore
 
                     else:
                         # pytorch-grad-cam will evaluate the output with the
-                        # highest value and produce a visualisation for it - we will
-                        # save it to disk.
-                        use_folder = output_folder / vis_type / "highest-output"
-                        vis = cam(input_tensor=image, targets=None)  # type: ignore
-                        _save_visualisation(use_folder, name, vis)  # type: ignore
+                        # highest value and produce a saliency map for it - we
+                        # will save it to disk.
+                        use_folder = (
+                            output_folder / algo_type / "highest-output"
+                        )
+                        saliency_map = saliency_map_callable(
+                            input_tensor=image,
+                            targets=None,  # type: ignore
+                        )
+                        _save_saliency_map(use_folder, name, saliency_map)  # type: ignore
                 else:
                     # binary classification model with a single output - just
                     # lay all cams uniformly like the original dataset
-                    use_folder = output_folder / vis_type
-                    vis = cam(
+                    use_folder = output_folder / algo_type
+                    saliency_map = saliency_map_callable(
                         input_tensor=image,
                         targets=[
                             ClassifierOutputTarget(0),  # type: ignore
                         ],
                     )
-                    _save_visualisation(use_folder, name, vis)  # type: ignore
+                    _save_saliency_map(use_folder, name, saliency_map)  # type: ignore
diff --git a/src/ptbench/models/typing.py b/src/ptbench/models/typing.py
index 3c64a873..883811ac 100644
--- a/src/ptbench/models/typing.py
+++ b/src/ptbench/models/typing.py
@@ -26,7 +26,7 @@ MultiClassPredictionSplit: typing.TypeAlias = typing.Mapping[
 ]
 """A series of predictions for different database splits."""
 
-VisualisationType: typing.TypeAlias = typing.Literal[
+SaliencyMapAlgorithm: typing.TypeAlias = typing.Literal[
     "ablationcam",
     "eigencam",
     "eigengradcam",
@@ -41,4 +41,4 @@ VisualisationType: typing.TypeAlias = typing.Literal[
     "scorecam",
     "xgradcam",
 ]
-"""Supported visualisation types."""
+"""Supported saliency map algorithms."""
diff --git a/src/ptbench/scripts/generate_saliencymaps.py b/src/ptbench/scripts/generate_saliencymaps.py
index 7ba30ede..c330f5de 100644
--- a/src/ptbench/scripts/generate_saliencymaps.py
+++ b/src/ptbench/scripts/generate_saliencymaps.py
@@ -10,7 +10,7 @@ import click
 from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 
-from ..models.typing import VisualisationType
+from ..models.typing import SaliencyMapAlgorithm
 from .click import ConfigCommand
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@@ -27,7 +27,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
    .. code:: sh
 
-      ptbench generate-saliencymaps -vv densenet tbx11k-v1-healthy-vs-atb --weight=path/to/model_final.pth --output-folder=path/to/visualisations
+      ptbench generate-saliencymaps -vv densenet tbx11k-v1-healthy-vs-atb --weight=path/to/model_final.pth --output-folder=path/to/output
 
 """,
 )
@@ -53,7 +53,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 @click.option(
     "--output-folder",
     "-o",
-    help="Path where to store the visualisations (created if does not exist)",
+    help="Path where to store saliency maps (created if does not exist)",
     required=True,
     type=click.Path(
         exists=False,
@@ -62,7 +62,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
         writable=True,
         path_type=pathlib.Path,
     ),
-    default="visualisations",
+    default="saliency-maps",
     cls=ResourceOption,
 )
 @click.option(
@@ -106,11 +106,13 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @click.option(
-    "--visualisation-types",
-    "-vt",
-    help="""visualisation techniques to be used. Can be called multiple times
+    "--saliency-map-algorithm",
+    "-s",
+    help="""Saliency map algorithm(s) to be used. Can be called multiple times
     with different techniques.""",
-    type=click.Choice(typing.get_args(VisualisationType), case_sensitive=False),
+    type=click.Choice(
+        typing.get_args(SaliencyMapAlgorithm), case_sensitive=False
+    ),
     multiple=True,
     default=["gradcam"],
     show_default=True,
@@ -149,7 +151,7 @@ def generate_saliencymaps(
     cache_samples,
     weight,
     parallel,
-    visualisation_types,
+    saliency_map_algorithm,
     target_class,
     positive_only,
     **_,
@@ -157,8 +159,8 @@ def generate_saliencymaps(
     """Generates saliency maps for locations on input images that affected the
     prediction.
 
-    The quality of saliency information depends on visualisation
-    technique and trained model.
+    The quality of saliency information depends on the saliency map
+    algorithm and trained model.
     """
 
     from ..engine.device import DeviceManager
@@ -189,7 +191,7 @@ def generate_saliencymaps(
         model=model,
         datamodule=datamodule,
         device_manager=device_manager,
-        visualisation_types=visualisation_types,
+        saliency_map_algorithms=saliency_map_algorithm,
         target_class=target_class,
         positive_only=positive_only,
         output_folder=output_folder,
-- 
GitLab