From 910e58c16e4b2c40890c2e99996a93a2a068ba32 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Fri, 15 Dec 2023 14:05:45 +0100
Subject: [PATCH] [scripts.view_saliency] New viewer implementation based on
 pillow and matplotlib instead of opencv

---
 src/ptbench/engine/saliency/viewer.py | 263 ++++++++++++++++++++++++++
 src/ptbench/scripts/cli.py            |   6 +-
 src/ptbench/scripts/view_saliency.py  | 137 ++++++++++++++
 3 files changed, 402 insertions(+), 4 deletions(-)
 create mode 100644 src/ptbench/engine/saliency/viewer.py
 create mode 100644 src/ptbench/scripts/view_saliency.py

diff --git a/src/ptbench/engine/saliency/viewer.py b/src/ptbench/engine/saliency/viewer.py
new file mode 100644
index 00000000..43011a38
--- /dev/null
+++ b/src/ptbench/engine/saliency/viewer.py
@@ -0,0 +1,263 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import logging
+import os
+import pathlib
+import typing
+
+import lightning.pytorch
+import matplotlib.pyplot
+import numpy
+import numpy.typing
+import PIL.Image
+import PIL.ImageColor
+import PIL.ImageDraw
+import torchvision.transforms.functional
+
+from tqdm import tqdm
+
+from ...config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes
+
+logger = logging.getLogger(__name__)
+
+
+def _overlay_saliency_map(
+    image: PIL.Image.Image,
+    saliencies: numpy.typing.NDArray[numpy.double],
+    colormap: typing.Literal[  # we accept any "Sequential" colormap from mpl
+        "viridis",
+        "plasma",
+        "inferno",
+        "magma",
+        "cividis",
+        "Greys",
+        "Purples",
+        "Blues",
+        "Greens",
+        "Oranges",
+        "Reds",
+        "YlOrBr",
+        "YlOrRd",
+        "OrRd",
+        "PuRd",
+        "RdPu",
+        "BuPu",
+        "GnBu",
+        "PuBu",
+        "YlGnBu",
+        "PuBuGn",
+        "BuGn",
+        "YlGn",
+    ],
+    image_weight: float,
+) -> PIL.Image.Image:
+    """Creates an overlayed represention of the saliency map on the original
+    image.
+
+    This is a slightly modified version of the ``show_cam_on_image``
+    implementation in https://github.com/jacobgil/pytorch-grad-cam, which
+    uses matplotlib instead of opencv.
+
+
+    Parameters
+    ----------
+    image
+        The input image that will be overlaid with the saliency map.
+    saliencies
+        The saliency map that will be overlaid on the (raw) image.
+    colormap
+        The name of the (matplotlib) colormap to be used.
+    image_weight
+        The final result is ``image_weight * image + (1-image_weight) *
+        saliency_map``.
+
+
+    Returns
+    -------
+        A modified version of the input ``image`` with the overlaid saliency
+        map.
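+
+    Example
+    -------
+    A minimal sketch of how this helper may be used (the 2x2 inputs below
+    are illustrative only):
+
+    .. code:: python
+
+       import numpy
+       import PIL.Image
+
+       image = PIL.Image.new("RGB", (2, 2), "white")
+       saliencies = numpy.array([[0.0, 0.5], [0.5, 1.0]])
+       overlaid = _overlay_saliency_map(
+           image, saliencies, colormap="plasma", image_weight=0.5
+       )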
+    """
+
+    image_array = numpy.array(image, dtype=numpy.float32) / 255.0
+
+    assert image_array.shape[:2] == saliencies.shape, (
+        f"The shape of the saliency map ({saliencies.shape}) is different "
+        f"from the shape of the input image ({image_array.shape[:2]})."
+    )
+
+    assert saliencies.min() >= 0 and saliencies.max() <= 1, (
+        f"The input saliency map should be in the range [0, 1] "
+        f"(min={saliencies.min()}, max={saliencies.max()})"
+    )
+
+    assert (
+        0 < image_weight < 1
+    ), f"image_weight should be in the open range (0, 1), but got {image_weight}"
+
+    # the colormap returns RGBA values in [0, 1]; we drop the alpha channel
+    # so shapes are compatible with the (RGB) input image during blending
+    heatmap = matplotlib.pyplot.cm.get_cmap(colormap)(saliencies)[:, :, :3]
+
+    # For pixels where the saliency is zero, the original image pixels are
+    # used, without any overlay.
+    result = numpy.where(
+        saliencies[..., numpy.newaxis] == 0,
+        image_array,
+        (image_weight * image_array) + ((1 - image_weight) * heatmap),
+    )
+
+    return PIL.Image.fromarray((result * 255).astype(numpy.uint8), "RGB")
+
+
+def _overlay_bounding_box(
+    image: PIL.Image.Image,
+    bbox: BoundingBox,
+    color: str,
+    width: int,
+) -> PIL.Image.Image:
+    """Draws ground-truth on the input image.
+
+    Parameters
+    ----------
+    image
+        The input image on which the bounding box will be drawn.
+    bbox
+        The bounding box to draw on the input image
+    color
+        The color to use for drawing the bounding box. Any of the colours in
+        :any:`PIL.ImageColor.colormap` are accepted.
+    width
+        The width of the bounding box outline, in pixels.  Larger values
+        produce thicker boxes, which grow towards the outside of the boxed
+        area.
+
+
+    Returns
+    -------
+        A modified version of the input ``image`` with the ground-truth drawn
+        on the top.
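+
+    Example
+    -------
+    A sketch of typical usage, assuming ``BoundingBox`` can be built directly
+    from its corner coordinates (the actual constructor lives in the TBX11k
+    datamodule and may differ):
+
+    .. code:: python
+
+       import PIL.Image
+
+       image = PIL.Image.new("RGB", (100, 100), "black")
+       bbox = BoundingBox(xmin=10, ymin=10, xmax=60, ymax=60)  # hypothetical
+       image = _overlay_bounding_box(image, bbox, color="green", width=2)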
+    """
+
+    draw = PIL.ImageDraw.Draw(image)
+    draw.rectangle(
+        (bbox.xmin, bbox.ymin, bbox.xmax, bbox.ymax),
+        outline=PIL.ImageColor.getrgb(color),
+        width=width,
+    )
+    return image
+
+
+def _process_sample(
+    raw_data: numpy.typing.NDArray[numpy.double],
+    saliencies: numpy.typing.NDArray[numpy.double],
+    ground_truth: BoundingBoxes,
+) -> PIL.Image.Image:
+    """Generates an overlayed representation of the original sample and
+    saliency maps.
+
+    Parameters
+    ----------
+    raw_data
+        The raw data representing the input sample that will be overlaid with
+        saliency maps and annotations.
+    saliencies
+        The saliency map recovered from the model, which will be imprinted on
+        the ``raw_data``.
+    ground_truth
+        Ground-truth annotations that may be imprinted on the final image.
+
+
+    Returns
+    -------
+        An image with the original raw data overlayed with the different
+        elements as selected by the user.
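+
+    Example
+    -------
+    A sketch with random inputs (shapes and values are illustrative only):
+
+    .. code:: python
+
+       import numpy
+       import torch
+
+       raw_data = torch.rand(1, 64, 64)  # a single-channel CxHxW sample
+       saliencies = numpy.random.rand(64, 64)  # values already in [0, 1]
+       image = _process_sample(raw_data, saliencies, BoundingBoxes())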
+    """
+
+    # we need a colour image to eventually overlay a (coloured) saliency map
+    # on the top, and to draw rectangles and other annotations in colour.
+    # So, we force the conversion right up front.
+    retval = torchvision.transforms.functional.to_pil_image(raw_data).convert(
+        "RGB"
+    )
+
+    retval = _overlay_saliency_map(
+        retval, saliencies, colormap="plasma", image_weight=0.5
+    )
+
+    for k in ground_truth:
+        retval = _overlay_bounding_box(retval, k, color="green", width=2)
+
+    return retval
+
+
+def run(
+    datamodule: lightning.pytorch.LightningDataModule,
+    input_folder: pathlib.Path,
+    target_label: int,
+    output_folder: pathlib.Path,
+    show_groundtruth: bool,
+    threshold: float,
+):
+    """Overlays saliency maps on CXR to output final images with heatmaps.
+
+    Parameters
+    ----------
+    datamodule
+        The lightning datamodule to iterate on.
+    input_folder
+        Directory in which the saliency maps are stored for a specific
+        visualization type.
+    target_label
+        The label to target for evaluating interpretability metrics. Samples
+        containing any other label are ignored.
+    output_folder
+        Directory in which the resulting visualisations will be saved.
+    show_groundtruth
+        If set, imprint ground truth labels over the original image and
+        saliency maps.
+    threshold
+        Pixel values above ``threshold`` (a fraction of the maximum saliency
+        value) are kept in the original saliency map.  Everything else is set
+        to zero.  The value proposed in [SCORECAM-2020]_ is 0.2.  Use this
+        value if unsure.
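+
+    Example
+    -------
+    The thresholding step keeps only the strongest activations.  A worked
+    sketch (values are illustrative):
+
+    .. code:: python
+
+       import numpy
+
+       saliencies = numpy.array([0.05, 0.1, 0.5, 1.0])
+       saliencies[saliencies < (0.2 * saliencies.max())] = 0
+       # saliencies is now [0. , 0. , 0.5, 1. ]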
+    """
+
+    for dataset_name, dataset_loader in datamodule.predict_dataloader().items():
+        logger.info(
+            f"Generating visualisations for samples at dataset `{dataset_name}`..."
+        )
+
+        for sample in tqdm(
+            dataset_loader, desc="batches", leave=False, disable=None
+        ):
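+            # each batch contains a single sample, as the caller sets the
+            # chunk size to 1; sample[0] holds the data, sample[1] the
+            # metadata (name, label and, possibly, bounding boxes)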
+            name = str(sample[1]["name"][0])
+            label = int(sample[1]["label"].item())
+            data = sample[0][0]
+
+            if label != target_label:
+                # no visualisation was generated
+                continue
+
+            saliencies = numpy.load(
+                input_folder / pathlib.Path(name).with_suffix(".npy")
+            )
+            saliencies[saliencies < (threshold * saliencies.max())] = 0
+
+            # TODO: This is very specific to the TBX11k system for labelling
+            # regions of interest.  We need to abstract this away to support
+            # more datasets and other ways of annotating.
+            if show_groundtruth:
+                ground_truth = sample[1].get("bounding_boxes", BoundingBoxes())
+            else:
+                ground_truth = BoundingBoxes()
+
+            # we fully process this entry
+            image = _process_sample(
+                data,
+                saliencies,
+                ground_truth,
+            )
+
+            # Save image
+            output_file_path = output_folder / pathlib.Path(name).with_suffix(
+                ".png"
+            )
+            os.makedirs(output_file_path.parent, exist_ok=True)
+            image.save(output_file_path)
diff --git a/src/ptbench/scripts/cli.py b/src/ptbench/scripts/cli.py
index 0311bef5..cd7ff093 100644
--- a/src/ptbench/scripts/cli.py
+++ b/src/ptbench/scripts/cli.py
@@ -7,7 +7,6 @@ import click
 from clapper.click import AliasedGroup
 
 from . import (
-    compare_vis,
     config,
     database,
     evaluate,
@@ -19,7 +18,7 @@ from . import (
     saliency_interpretability,
     train,
     train_analysis,
-    visualize,
+    view_saliency,
 )
 
 
@@ -32,7 +31,6 @@ def cli():
     pass
 
 
-cli.add_command(compare_vis.compare_vis)
 cli.add_command(config.config)
 cli.add_command(database.database)
 cli.add_command(evaluate.evaluate)
@@ -44,4 +42,4 @@ cli.add_command(generate_saliencymaps.generate_saliencymaps)
 cli.add_command(predict.predict)
 cli.add_command(train.train)
 cli.add_command(train_analysis.train_analysis)
-cli.add_command(visualize.visualize)
+cli.add_command(view_saliency.view_saliency)
diff --git a/src/ptbench/scripts/view_saliency.py b/src/ptbench/scripts/view_saliency.py
new file mode 100644
index 00000000..c3b9379f
--- /dev/null
+++ b/src/ptbench/scripts/view_saliency.py
@@ -0,0 +1,137 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import os
+import pathlib
+
+import click
+
+from clapper.click import ConfigCommand, ResourceOption, verbosity_option
+from clapper.logging import setup
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+@click.command(
+    entry_point_group="ptbench.config",
+    cls=ConfigCommand,
+    epilog="""Examples:
+
+\b
+    1. Generates visualizations in the form of heatmaps from existing saliency maps for a dataset configuration:
+
+       .. code:: sh
+
+          ptbench view-saliency -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-folder=path/to/visualizations
+
+""",
+)
+@click.option(
+    "--model",
+    "-m",
+    help="A lightining module instance implementing the network to be used for applying the necessary data transformations.",
+    required=True,
+    cls=ResourceOption,
+)
+@click.option(
+    "--datamodule",
+    "-d",
+    help="A lighting data module containing the training, validation and test sets.",
+    required=True,
+    cls=ResourceOption,
+)
+@click.option(
+    "--input-folder",
+    "-i",
+    help="Path to the folder containing the saliency maps for a specific visualization type.",
+    required=True,
+    type=click.Path(
+        exists=True,
+        file_okay=False,
+        dir_okay=True,
+        path_type=pathlib.Path,
+    ),
+    cls=ResourceOption,
+)
+@click.option(
+    "--output-folder",
+    "-o",
+    help="Path where to store the ROAD scores (created if does not exist)",
+    required=True,
+    type=click.Path(
+        file_okay=False,
+        dir_okay=True,
+        writable=True,
+        path_type=pathlib.Path,
+    ),
+    default="visualizations",
+    cls=ResourceOption,
+)
+@click.option(
+    "--show-groundtruth/--no-show-groundtruth",
+    "-G/-g",
+    help="""If set, visualizations for ground truth labels will be generated.
+    Only works for datasets with bounding boxes.""",
+    is_flag=True,
+    default=False,
+    cls=ResourceOption,
+)
+@click.option(
+    "--threshold",
+    "-t",
+    help="""The pixel values above ``threshold``% of max value are kept in the
+    original saliency map.  Everything else is set to zero.  The value proposed
+    on [SCORECAM-2020]_ is 0.2.  Use this value if unsure.""",
+    show_default=True,
+    required=True,
+    default=0.2,
+    type=click.FloatRange(min=0, max=1),
+    cls=ResourceOption,
+)
+@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
+def view_saliency(
+    model,
+    datamodule,
+    input_folder,
+    output_folder,
+    show_groundtruth,
+    threshold,
+    **_,
+) -> None:
+    """Generates heatmaps for input CXRs based on existing saliency maps."""
+
+    from ..engine.saliency.viewer import run
+    from .utils import save_sh_command
+
+    assert (
+        input_folder != output_folder
+    ), "Output folder must not be the same as the input folder."
+
+    assert not str(output_folder).startswith(
+        str(input_folder)
+    ), "Output folder must not be a subdirectory of the input folder."
+
+    logger.info(f"Output folder: {output_folder}")
+    os.makedirs(output_folder, exist_ok=True)
+
+    save_sh_command(output_folder / "command.sh")
+
+    datamodule.set_chunk_size(1, 1)
+    datamodule.drop_incomplete_batch = False
+    datamodule.model_transforms = model.model_transforms
+
+    datamodule.prepare_data()
+    datamodule.setup(stage="predict")
+
+    run(
+        datamodule=datamodule,
+        input_folder=input_folder,
+        target_label=1,
+        output_folder=output_folder,
+        show_groundtruth=show_groundtruth,
+        threshold=threshold,
+    )
-- 
GitLab