diff --git a/src/ptbench/engine/saliency/viewer.py b/src/ptbench/engine/saliency/viewer.py
new file mode 100644
index 0000000000000000000000000000000000000000..43011a385698f2a2dc3d2b19b256a2b4c1f7d1b2
--- /dev/null
+++ b/src/ptbench/engine/saliency/viewer.py
@@ -0,0 +1,263 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import logging
+import os
+import pathlib
+import typing
+
+import lightning.pytorch
+import matplotlib.pyplot
+import numpy
+import numpy.typing
+import PIL.Image
+import PIL.ImageColor
+import PIL.ImageDraw
+import torchvision.transforms.functional
+
+from tqdm import tqdm
+
+from ...config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes
+
+logger = logging.getLogger(__name__)
+
+
+def _overlay_saliency_map(
+    image: PIL.Image.Image,
+    saliencies: numpy.typing.NDArray[numpy.double],
+    colormap: typing.Literal[  # we accept any "Sequential" colormap from mpl
+        "viridis",
+        "plasma",
+        "inferno",
+        "magma",
+        "cividis",
+        "Greys",
+        "Purples",
+        "Blues",
+        "Greens",
+        "Oranges",
+        "Reds",
+        "YlOrBr",
+        "YlOrRd",
+        "OrRd",
+        "PuRd",
+        "RdPu",
+        "BuPu",
+        "GnBu",
+        "PuBu",
+        "YlGnBu",
+        "PuBuGn",
+        "BuGn",
+        "YlGn",
+    ],
+    image_weight: float,
+) -> PIL.Image.Image:
+    """Creates an overlaid representation of the saliency map on the original
+    image.
+
+    This is a slightly modified version of the ``show_cam_on_image``
+    implementation in https://github.com/jacobgil/pytorch-grad-cam, but uses
+    matplotlib instead of opencv.
+
+
+    Parameters
+    ----------
+    image
+        The input image that will be overlaid with the saliency map
+    saliencies
+        The saliency map that will be overlaid on the (raw) image
+    colormap
+        The name of the (matplotlib) colormap to be used
+    image_weight
+        The final result is ``image_weight * image + (1-image_weight) *
+        saliency_map``.
+
+
+    Returns
+    -------
+    A modified version of the input ``image`` with the overlaid saliency
+    map.
+    """
+
+    image_array = numpy.array(image, dtype=numpy.float32) / 255.0
+
+    assert image_array.shape[:2] == saliencies.shape, (
+        f"The shape of the saliency map ({saliencies.shape}) is different "
+        f"from the shape of the input image ({image_array.shape[:2]})."
+    )
+
+    assert (
+        saliencies.max() <= 1
+    ), f"The input saliency map should be in the range [0, 1] (max={saliencies.max()})"
+
+    assert (
+        image_weight > 0 and image_weight < 1
+    ), f"image_weight should be in the range (0, 1), but got {image_weight}"
+
+    heatmap = matplotlib.pyplot.cm.get_cmap(colormap)(saliencies)
+
+    # For pixels where the mask is zero, the original image pixels are used
+    # without blending.  The colormap output is RGBA, so the alpha channel is
+    # dropped before blending with the (RGB) input image.
+    result = numpy.where(
+        saliencies[..., numpy.newaxis] == 0,
+        image_array,
+        (image_weight * image_array)
+        + ((1 - image_weight) * heatmap[:, :, :3]),
+    )
+
+    return PIL.Image.fromarray((result * 255).astype(numpy.uint8), "RGB")
+
+
+def _overlay_bounding_box(
+    image: PIL.Image.Image,
+    bbox: BoundingBox,
+    color: str,
+    width: int,
+) -> PIL.Image.Image:
+    """Draws a ground-truth bounding box on the input image.
+
+    Parameters
+    ----------
+    image
+        The input image on which the bounding box will be drawn
+    bbox
+        The bounding box to draw on the input image
+    color
+        The color to use for drawing the bounding box. Any of the colours in
+        :any:`PIL.ImageColor.colormap` are accepted.
+    width
+        The width of the bounding box, in pixels. A larger value creates a
+        bounding box that is thicker, towards the outside of the boxed area.
+
+
+    Returns
+    -------
+    A modified version of the input ``image`` with the ground-truth drawn
+    on the top.
+    """
+
+    draw = PIL.ImageDraw.Draw(image)
+    draw.rectangle(
+        (bbox.xmin, bbox.ymin, bbox.xmax, bbox.ymax),
+        outline=PIL.ImageColor.getrgb(color),
+        width=width,
+    )
+    return image
+
+
+def _process_sample(
+    raw_data: numpy.typing.NDArray[numpy.double],
+    saliencies: numpy.typing.NDArray[numpy.double],
+    ground_truth: BoundingBoxes,
+) -> PIL.Image.Image:
+    """Generates an overlaid representation of the original sample and
+    saliency maps.
+
+    Parameters
+    ----------
+    raw_data
+        The raw data representing the input sample that will be overlaid with
+        saliency maps and annotations
+    saliencies
+        The saliency map recovered from the model, that will be imprinted on
+        the ``raw_data``
+    ground_truth
+        Ground-truth annotations that may be imprinted on the final image
+
+
+    Returns
+    -------
+    An image with the original raw data overlaid with the different
+    elements as selected by the user.
+    """
+
+    # we need a colour image to eventually overlay a (coloured) saliency map
+    # on the top, and to draw rectangles and other annotations in colour, so
+    # we force the conversion right up front.
+    retval = torchvision.transforms.functional.to_pil_image(raw_data).convert(
+        "RGB"
+    )
+
+    retval = _overlay_saliency_map(
+        retval, saliencies, colormap="plasma", image_weight=0.5
+    )
+
+    for k in ground_truth:
+        retval = _overlay_bounding_box(retval, k, color="green", width=2)
+
+    return retval
+
+
+def run(
+    datamodule: lightning.pytorch.LightningDataModule,
+    input_folder: pathlib.Path,
+    target_label: int,
+    output_folder: pathlib.Path,
+    show_groundtruth: bool,
+    threshold: float,
+):
+    """Overlays saliency maps on CXRs to output final images with heatmaps.
+
+    Parameters
+    ----------
+    datamodule
+        The lightning datamodule to iterate on.
+    input_folder
+        Directory in which the saliency maps are stored for a specific
+        visualization type.
+    target_label
+        The label to target for evaluating interpretability metrics. Samples
+        containing any other label are ignored.
+    output_folder
+        Directory in which the resulting visualisations will be saved.
+    show_groundtruth
+        If set, imprint ground-truth annotations over the original image and
+        saliency maps.
+    threshold
+        Pixel values below ``threshold`` times the maximum value of the
+        original saliency map are set to zero; everything else is kept.  The
+        value proposed in [SCORECAM-2020]_ is 0.2.  Use this value if unsure.
+    """
+
+    for dataset_name, dataset_loader in datamodule.predict_dataloader().items():
+        logger.info(
+            f"Generating visualisations for samples at dataset `{dataset_name}`..."
+        )
+
+        for sample in tqdm(
+            dataset_loader, desc="batches", leave=False, disable=None
+        ):
+            name = str(sample[1]["name"][0])
+            label = int(sample[1]["label"].item())
+            data = sample[0][0]
+
+            if label != target_label:
+                # no saliency map was generated for this sample
+                continue
+
+            saliencies = numpy.load(
+                input_folder / pathlib.Path(name).with_suffix(".npy")
+            )
+            saliencies[saliencies < (threshold * saliencies.max())] = 0
+
+            # TODO: This is very specific to the TBX11k system for labelling
+            # regions of interest.  We need to abstract from this to support
+            # more datasets and other ways to annotate.
+            if show_groundtruth:
+                ground_truth = sample[1].get("bounding_boxes", BoundingBoxes())
+            else:
+                ground_truth = BoundingBoxes()
+
+            # we fully process this entry
+            image = _process_sample(
+                data,
+                saliencies,
+                ground_truth,
+            )
+
+            # save the resulting image
+            output_file_path = output_folder / pathlib.Path(name).with_suffix(
+                ".png"
+            )
+            os.makedirs(output_file_path.parent, exist_ok=True)
+            image.save(output_file_path)
diff --git a/src/ptbench/scripts/cli.py b/src/ptbench/scripts/cli.py
index 0311bef5e705d9aeb466f2af4937e759bcf20b2f..cd7ff093379ab31fbf8eeeb49eacfdb239f2d8b1 100644
--- a/src/ptbench/scripts/cli.py
+++ b/src/ptbench/scripts/cli.py
@@ -7,7 +7,6 @@ import click
 from clapper.click import AliasedGroup
 
 from . import (
-    compare_vis,
     config,
     database,
     evaluate,
@@ -19,7 +18,7 @@ from . import (
     saliency_interpretability,
     train,
     train_analysis,
-    visualize,
+    view_saliency,
 )
 
 
@@ -32,7 +31,6 @@ def cli():
     pass
 
 
-cli.add_command(compare_vis.compare_vis)
 cli.add_command(config.config)
 cli.add_command(database.database)
 cli.add_command(evaluate.evaluate)
@@ -44,4 +42,4 @@ cli.add_command(generate_saliencymaps.generate_saliencymaps)
 cli.add_command(predict.predict)
 cli.add_command(train.train)
 cli.add_command(train_analysis.train_analysis)
-cli.add_command(visualize.visualize)
+cli.add_command(view_saliency.view_saliency)
diff --git a/src/ptbench/scripts/view_saliency.py b/src/ptbench/scripts/view_saliency.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3b9379f4bda40665999a330f9a157fb8a70cfbc
--- /dev/null
+++ b/src/ptbench/scripts/view_saliency.py
@@ -0,0 +1,137 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import os
+import pathlib
+
+import click
+
+from clapper.click import ConfigCommand, ResourceOption, verbosity_option
+from clapper.logging import setup
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+@click.command(
+    entry_point_group="ptbench.config",
+    cls=ConfigCommand,
+    epilog="""Examples:
+
+\b
+    1. Generate visualizations, in the form of heatmaps, from existing
+       saliency maps for a dataset configuration:
+
+       .. code:: sh
+
+          ptbench view-saliency -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-folder=path/to/visualizations
+
+""",
+)
+@click.option(
+    "--model",
+    "-m",
+    help="A lightning module instance implementing the network to be used "
+    "for applying the necessary data transformations.",
+    required=True,
+    cls=ResourceOption,
+)
+@click.option(
+    "--datamodule",
+    "-d",
+    help="A lightning data module containing the training, validation and "
+    "test sets.",
+    required=True,
+    cls=ResourceOption,
+)
+@click.option(
+    "--input-folder",
+    "-i",
+    help="Path to the folder containing the saliency maps for a specific "
+    "visualization type.",
+    required=True,
+    type=click.Path(
+        exists=True,
+        file_okay=False,
+        dir_okay=True,
+        path_type=pathlib.Path,
+    ),
+    default="visualizations",
+    cls=ResourceOption,
+)
+@click.option(
+    "--output-folder",
+    "-o",
+    help="Path where to store the generated visualizations (created if it "
+    "does not exist).",
+    required=True,
+    type=click.Path(
+        file_okay=False,
+        dir_okay=True,
+        writable=True,
+        path_type=pathlib.Path,
+    ),
+    default="visualizations",
+    cls=ResourceOption,
+)
+@click.option(
+    "--show-groundtruth/--no-show-groundtruth",
+    "-G/-g",
+    help="""If set, visualizations for ground-truth labels will be generated.
+    Only works for datasets with bounding boxes.""",
+    is_flag=True,
+    default=False,
+    cls=ResourceOption,
+)
+@click.option(
+    "--threshold",
+    "-t",
+    help="""Pixel values below ``threshold`` times the maximum value of the
+    original saliency map are set to zero; everything else is kept.  The
+    value proposed in [SCORECAM-2020]_ is 0.2.  Use this value if unsure.""",
+    show_default=True,
+    required=True,
+    default=0.2,
+    type=click.FloatRange(min=0, max=1),
+    cls=ResourceOption,
+)
+@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
+def view_saliency(
+    model,
+    datamodule,
+    input_folder,
+    output_folder,
+    show_groundtruth,
+    threshold,
+    **_,
+) -> None:
+    """Generates heatmaps for input CXRs based on existing saliency maps."""
+
+    from ..engine.saliency.viewer import run
+    from .utils import save_sh_command
+
+    assert (
+        input_folder != output_folder
+    ), "Output folder must not be the same as the input folder."
+
+    # resolve both paths before comparing, so sibling folders that merely
+    # share a name prefix are not flagged
+    assert not output_folder.resolve().is_relative_to(
+        input_folder.resolve()
+    ), "Output folder must not be a subdirectory of the input folder."
+
+    logger.info(f"Output folder: {output_folder}")
+    os.makedirs(output_folder, exist_ok=True)
+
+    save_sh_command(output_folder / "command.sh")
+
+    # process one sample at a time, without dropping incomplete batches
+    datamodule.set_chunk_size(1, 1)
+    datamodule.drop_incomplete_batch = False
+    datamodule.model_transforms = model.model_transforms
+
+    datamodule.prepare_data()
+    datamodule.setup(stage="predict")
+
+    run(
+        datamodule=datamodule,
+        input_folder=input_folder,
+        target_label=1,
+        output_folder=output_folder,
+        show_groundtruth=show_groundtruth,
+        threshold=threshold,
+    )
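
Note (reviewer sketch, not part of the patch): the threshold-and-blend
technique that ``run()`` and ``_overlay_saliency_map()`` implement can be
exercised in isolation.  The snippet below is a minimal, self-contained
approximation of what happens to a single sample; the file names ``cxr.png``
and ``cxr.npy`` are hypothetical placeholders, and ``threshold=0.2`` and
``image_weight=0.5`` mirror the defaults used in the patch.

.. code:: py

   import matplotlib.pyplot
   import numpy
   import PIL.Image

   # hypothetical inputs: a CXR and its saliency map, stored side by side
   image = PIL.Image.open("cxr.png").convert("RGB")
   saliencies = numpy.load("cxr.npy")  # floats in [0, 1], same HxW as image

   # zero-out everything below 20% of the map's maximum (threshold=0.2)
   saliencies[saliencies < 0.2 * saliencies.max()] = 0

   image_array = numpy.array(image, dtype=numpy.float32) / 255.0
   # the colormap output is RGBA; keep only the RGB channels for blending
   heatmap = matplotlib.pyplot.cm.get_cmap("plasma")(saliencies)[:, :, :3]

   image_weight = 0.5  # same 50/50 blend used by _process_sample()
   result = numpy.where(
       saliencies[..., numpy.newaxis] == 0,  # untouched where the map is zero
       image_array,
       image_weight * image_array + (1 - image_weight) * heatmap,
   )
   PIL.Image.fromarray((result * 255).astype(numpy.uint8), "RGB").save(
       "overlay.png"
   )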