Commit 910e58c1 authored by André Anjos, committed by Daniel CARRON

[scripts.view_saliency] New viewer implementation based on pillow and matplotlib instead of opencv

parent 953abf62
1 merge request: !12 Adds grad-cam support on classifiers
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import os
import pathlib
import typing
import lightning.pytorch
import matplotlib.pyplot
import numpy
import numpy.typing
import PIL.Image
import PIL.ImageColor
import PIL.ImageDraw
import torchvision.transforms.functional
from tqdm import tqdm
from ...config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes
logger = logging.getLogger(__name__)
def _overlay_saliency_map(
image: PIL.Image.Image,
saliencies: numpy.typing.NDArray[numpy.double],
colormap: typing.Literal[ # we accept any "Sequential" colormap from mpl
"viridis",
"plasma",
"inferno",
"magma",
"cividis",
"Greys",
"Purples",
"Blues",
"Greens",
"Oranges",
"Reds",
"YlOrBr",
"YlOrRd",
"OrRd",
"PuRd",
"RdPu",
"BuPu",
"GnBu",
"PuBu",
"YlGnBu",
"PuBuGn",
"BuGn",
"YlGn",
],
image_weight: float,
) -> PIL.Image.Image:
"""Creates an overlayed represention of the saliency map on the original
image.
This is a slightly modified version of the show_cam_on_image implementation in:
https://github.com/jacobgil/pytorch-grad-cam, but uses matplotlib instead
of opencv.
Parameters
----------
image
        The input image that will be overlaid with the saliency map
saliencies
The saliency map that will be overlaid on the (raw) image
colormap
The name of the (matplotlib) colormap to be used
image_weight
The final result is ``image_weight * image + (1-image_weight) *
saliency_map``.
Returns
-------
A modified version of the input ``image`` with the overlaid saliency
map.
"""
image_array = numpy.array(image, dtype=numpy.float32) / 255.0
assert image_array.shape[:2] == saliencies.shape, (
f"The shape of the saliency map ({saliencies.shape}) is different "
f"from the shape of the input image ({image_array.shape[:2]})."
)
assert (
saliencies.max() <= 1
), f"The input saliency map should be in the range [0, 1] (max={saliencies.max()})"
    assert (
        0 < image_weight < 1
    ), f"image_weight should be in the open range (0, 1), but got {image_weight}"
    # matplotlib colormaps return RGBA values in [0, 1]; we drop the alpha
    # channel so the blending below broadcasts against the RGB input image
    heatmap = matplotlib.pyplot.cm.get_cmap(colormap)(saliencies)[:, :, :3]

    # for pixels where the saliency map is zero, the original image pixels
    # are used unchanged; elsewhere, image and heatmap are blended
    result = numpy.where(
        saliencies[..., numpy.newaxis] == 0,
        image_array,
        (image_weight * image_array) + ((1 - image_weight) * heatmap),
    )
return PIL.Image.fromarray((result * 255).astype(numpy.uint8), "RGB")
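
# A minimal usage sketch for ``_overlay_saliency_map`` (illustrative only;
# the image, saliency values and output file name below are made up):
#
#     img = PIL.Image.new("RGB", (2, 2), "white")
#     sal = numpy.array([[0.0, 0.5], [0.8, 1.0]])
#     out = _overlay_saliency_map(img, sal, colormap="plasma", image_weight=0.5)
#     out.save("overlay-example.png")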
def _overlay_bounding_box(
image: PIL.Image.Image,
bbox: BoundingBox,
color: str,
width: int,
) -> PIL.Image.Image:
"""Draws ground-truth on the input image.
Parameters
----------
image
        The input image on which the bounding box will be drawn
bbox
The bounding box to draw on the input image
color
The color to use for drawing the bounding box. Any of the colours in
:any:`PIL.ImageColor.colormap` are accepted.
width
The width of the bounding box, in pixels. A larger value creates a
bounding box that is thicker, towards the outside of the boxed area.
Returns
-------
    A modified version of the input ``image`` with the ground-truth drawn
    on top.
"""
draw = PIL.ImageDraw.Draw(image)
draw.rectangle(
(bbox.xmin, bbox.ymin, bbox.xmax, bbox.ymax),
outline=PIL.ImageColor.getrgb(color),
width=width,
)
return image
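
# Illustrative sketch for ``_overlay_bounding_box``; it only relies on the
# ``xmin``/``ymin``/``xmax``/``ymax`` attributes used above, and the
# ``BoundingBox`` constructor shown here is an assumption:
#
#     img = PIL.Image.new("RGB", (100, 100), "black")
#     box = BoundingBox(label=1, xmin=10, ymin=10, xmax=50, ymax=50)
#     img = _overlay_bounding_box(img, box, color="green", width=2)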
def _process_sample(
raw_data: numpy.typing.NDArray[numpy.double],
saliencies: numpy.typing.NDArray[numpy.double],
ground_truth: BoundingBoxes,
) -> PIL.Image.Image:
"""Generates an overlayed representation of the original sample and
saliency maps.
Parameters
----------
raw_data
        The raw data representing the input sample that will be overlaid
        with saliency maps and annotations
    saliencies
        The saliency map recovered from the model, that will be imprinted
        on the raw_data
    ground_truth
        Ground-truth annotations that may be imprinted on the final image
Returns
-------
An image with the original raw data overlayed with the different
elements as selected by the user.
"""
    # we need a colour image to overlay the (coloured) saliency map on top
    # and to draw rectangles and other annotations in colour, so we force
    # the conversion right up front
retval = torchvision.transforms.functional.to_pil_image(raw_data).convert(
"RGB"
)
retval = _overlay_saliency_map(
retval, saliencies, colormap="plasma", image_weight=0.5
)
for k in ground_truth:
retval = _overlay_bounding_box(retval, k, color="green", width=2)
return retval
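
# End-to-end sketch of the helper chain above, using random data in place of
# a real CXR (shapes and values are illustrative):
#
#     import torch
#     raw = torch.rand(1, 512, 512)  # single-plane image, values in [0, 1]
#     sal = numpy.random.rand(512, 512)
#     sal /= sal.max()  # normalise to [0, 1], as asserted downstream
#     overlaid = _process_sample(raw, sal, BoundingBoxes())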
def run(
datamodule: lightning.pytorch.LightningDataModule,
input_folder: pathlib.Path,
target_label: int,
output_folder: pathlib.Path,
show_groundtruth: bool,
threshold: float,
):
"""Overlays saliency maps on CXR to output final images with heatmaps.
Parameters
----------
datamodule
The lightning datamodule to iterate on.
input_folder
Directory in which the saliency maps are stored for a specific
visualization type.
target_label
The label to target for evaluating interpretability metrics. Samples
        containing any other label are ignored.
output_folder
Directory in which the resulting visualisations will be saved.
show_groundtruth
        If set, imprint ground truth labels over the original image and
        saliency maps.
    threshold
        Pixel values above ``threshold`` times the maximum saliency value
        are kept in the original saliency map; everything else is set to
        zero. The value proposed in [SCORECAM-2020]_ is 0.2. Use this
        value if unsure.
"""
for dataset_name, dataset_loader in datamodule.predict_dataloader().items():
logger.info(
f"Generating visualisations for samples at dataset `{dataset_name}`..."
)
for sample in tqdm(
dataset_loader, desc="batches", leave=False, disable=None
):
name = str(sample[1]["name"][0])
label = int(sample[1]["label"].item())
data = sample[0][0]
if label != target_label:
# no visualisation was generated
continue
saliencies = numpy.load(
input_folder / pathlib.Path(name).with_suffix(".npy")
)
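            # keep only responses above ``threshold`` of the map's maximum:
            # e.g. with threshold=0.2 and a maximum of 0.9, every value
            # below 0.18 is zeroed before overlaying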
saliencies[saliencies < (threshold * saliencies.max())] = 0
# TODO: This is very specific to the TBX11k system for labelling
# regions of interest. We need to abstract from this to support more
# datasets and other ways to annotate.
if show_groundtruth:
ground_truth = sample[1].get("bounding_boxes", BoundingBoxes())
else:
ground_truth = BoundingBoxes()
# we fully process this entry
image = _process_sample(
data,
saliencies,
ground_truth,
)
# Save image
output_file_path = output_folder / pathlib.Path(name).with_suffix(
".png"
)
os.makedirs(output_file_path.parent, exist_ok=True)
image.save(output_file_path)
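
# Saved visualisations mirror the saliency-map tree, swapping the ``.npy``
# suffix for ``.png``, e.g. (hypothetical sample name):
#
#     input_folder/train/sample-0001.npy -> output_folder/train/sample-0001.png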
@@ -7,7 +7,6 @@ import click
 from clapper.click import AliasedGroup
 from . import (
-    compare_vis,
     config,
     database,
     evaluate,
@@ -19,7 +18,7 @@ from . import (
     saliency_interpretability,
     train,
     train_analysis,
-    visualize,
+    view_saliency,
 )
@@ -32,7 +31,6 @@ def cli():
     pass
-cli.add_command(compare_vis.compare_vis)
 cli.add_command(config.config)
 cli.add_command(database.database)
 cli.add_command(evaluate.evaluate)
@@ -44,4 +42,4 @@ cli.add_command(generate_saliencymaps.generate_saliencymaps)
 cli.add_command(predict.predict)
 cli.add_command(train.train)
 cli.add_command(train_analysis.train_analysis)
-cli.add_command(visualize.visualize)
+cli.add_command(view_saliency.view_saliency)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import os
import pathlib
import click
from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.logging import setup
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@click.command(
    entry_point_group="ptbench.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
  1. Generates visualizations in the form of heatmaps from existing saliency
     maps for a dataset configuration:

     .. code:: sh

        ptbench view-saliency -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-folder=path/to/visualizations
""",
)
@click.option(
    "--model",
    "-m",
    help="A lightning module instance implementing the network to be used for applying the necessary data transformations.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--datamodule",
    "-d",
    help="A lightning data module containing the training, validation and test sets.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--input-folder",
    "-i",
    help="Path to the folder containing the saliency maps for a specific visualization type.",
    required=True,
    type=click.Path(
        file_okay=False,
        dir_okay=True,
        exists=True,
        path_type=pathlib.Path,
    ),
    default="visualizations",
    cls=ResourceOption,
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the visualizations (created if it does not exist).",
    required=True,
    type=click.Path(
        file_okay=False,
        dir_okay=True,
        writable=True,
        path_type=pathlib.Path,
    ),
    default="visualizations",
    cls=ResourceOption,
)
@click.option(
"--show-groundtruth/--no-show-groundtruth",
"-G/-g",
help="""If set, visualizations for ground truth labels will be generated.
Only works for datasets with bounding boxes.""",
is_flag=True,
default=False,
cls=ResourceOption,
)
@click.option(
    "--threshold",
    "-t",
    help="""Pixel values above ``threshold`` times the maximum saliency value
    are kept in the original saliency map; everything else is set to zero.
    The value proposed in [SCORECAM-2020]_ is 0.2. Use this value if
    unsure.""",
    show_default=True,
    required=True,
    default=0.2,
    type=click.FloatRange(min=0, max=1),
    cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def view_saliency(
model,
datamodule,
input_folder,
output_folder,
show_groundtruth,
threshold,
**_,
) -> None:
"""Generates heatmaps for input CXRs based on existing saliency maps."""
from ..engine.saliency.viewer import run
from .utils import save_sh_command
assert (
input_folder != output_folder
), "Output folder must not be the same as the input folder."
assert not str(output_folder).startswith(
str(input_folder)
), "Output folder must not be a subdirectory of the input folder."
logger.info(f"Output folder: {output_folder}")
os.makedirs(output_folder, exist_ok=True)
save_sh_command(output_folder / "command.sh")
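    # a chunk size of one presumably yields single-sample batches, which the
    # viewer code assumes (it indexes ``sample[1]["name"][0]`` and calls
    # ``.item()`` on the label)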
datamodule.set_chunk_size(1, 1)
datamodule.drop_incomplete_batch = False
# datamodule.cache_samples = cache_samples
# datamodule.parallel = parallel
datamodule.model_transforms = model.model_transforms
datamodule.prepare_data()
datamodule.setup(stage="predict")
run(
datamodule=datamodule,
input_folder=input_folder,
target_label=1,
output_folder=output_folder,
show_groundtruth=show_groundtruth,
threshold=threshold,
)