Commits on Source (61): showing 1862 additions and 113 deletions (target project: medai/software/mednet)
......@@ -25,3 +25,4 @@ _citools/
_work/
.mypy_cache/
.pytest_cache/
results/
......@@ -10,3 +10,17 @@ include:
variables:
GIT_SUBMODULE_STRATEGY: normal
GIT_SUBMODULE_DEPTH: 1
documentation:
before_script:
# for opencv-python dependence
- apt-get update && apt-get install -y libgl1-mesa-glx > /dev/null
- !reference [.snippets, doc-prepare]
tests:
before_script:
# for opencv-python dependence
- if [ "$TAG" == "docker" ]; then apt-get update && apt-get install -y libgl1-mesa-glx > /dev/null; fi
# for pytorch quirks on linux
- if [ "$TAG" == "docker" ]; then export OMP_NUM_THREADS=1; fi # Set the number of threads used to 1
- !reference [.snippets, test-prepare]
......@@ -33,7 +33,7 @@ repos:
--ignore-missing-imports,
]
- repo: https://github.com/asottile/pyupgrade
rev: v3.11.0
rev: v3.14.0
hooks:
- id: pyupgrade
args: [--py39-plus]
......
......@@ -24,18 +24,21 @@ requirements:
- python >=3.10
- pip
- click {{ click }}
- grad-cam {{ grad_cam }}
- matplotlib {{ matplotlib }}
- numpy {{ numpy }}
- pillow {{ pillow }}
- psutil {{ psutil }}
- pytorch {{ pytorch }}
- scikit-image {{ scikit_image }}
- scikit-learn {{ scikit_learn }}
- scipy {{ scipy }}
- tabulate {{ tabulate }}
- torchvision {{ torchvision }}
- tqdm {{ tqdm }}
- tensorboard {{ tensorboard }}
- lightning >=2.0.3
- lightning {{ lightning }}
- lightning >=2.1.0,!=2.1.3
- clapper
run:
- python >=3.10
......@@ -45,13 +48,15 @@ requirements:
- {{ pin_compatible('pillow') }}
- {{ pin_compatible('psutil') }}
- {{ pin_compatible('pytorch') }}
- {{ pin_compatible('scikit-image') }}
- {{ pin_compatible('scikit-learn') }}
- {{ pin_compatible('scipy') }}
- {{ pin_compatible('tabulate') }}
- {{ pin_compatible('torchvision') }}
- {{ pin_compatible('tqdm') }}
- {{ pin_compatible('tensorboard') }}
- {{ pin_compatible('lightning') }}
- {{ pin_compatible('lightning', max_pin='x.x') }}
- lightning >=2.1.0,!=2.1.3
- clapper
test:
......
......@@ -65,6 +65,20 @@ Functions to actuate on the data.
ptbench.engine.evaluator
.. _ptbench.api.saliency:
Saliency Map Generation and Analysis
------------------------------------
Engines to generate and analyze saliency mapping techniques.
.. autosummary::
:toctree: api/saliency
ptbench.engine.saliency.generator
ptbench.engine.saliency.completeness
ptbench.engine.saliency.interpretability
.. _ptbench.api.utils:
......
......@@ -45,6 +45,18 @@ We support two installation modes, through pip_, or mamba_ (conda).
mamba install -c https://www.idiap.ch/software/biosignal/conda/label/beta -c conda-forge ptbench
.. tip::
To force-install Nvidia GPU support on Linux machines, execute:
.. code:: sh
$ mamba install pytorch-gpu
# or, to force the Nvidia CUDA version (environments w/o Nvidia setup):
$ CONDA_OVERRIDE_CUDA=11.2 mamba install 'pytorch-gpu=*=cuda112*'
.. _ptbench.setup:
......
......@@ -62,6 +62,22 @@
Recognition, pages 2646–2655.
.. [TBX11K-SIMPLIFIED-2020] *Liu, Y., Wu, Y.-H., Ban, Y., Wang, H., and Cheng, M.-*,
**Rethinking computer-aided tuberculosis diagnosis.**,
**Rethinking computer-aided tuberculosis diagnosis**,
In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern
Recognition, pages 2646–2655.
.. [GRADCAM-2015] *B. Zhou, A. Khosla, A. Lapedriza, A. Oliva, and A.
Torralba*, **Learning Deep Features for Discriminative Localization**, In
2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). doi:
https://doi.org/10.1109/CVPR.2016.319.
.. [SCORECAM-2020] *H. Wang et al.*, **Score-CAM: Score-Weighted Visual
Explanations for Convolutional Neural Networks**, 2020 IEEE/CVF Conference on
Computer Vision and Pattern Recognition Workshops (CVPRW), Seattle, WA, USA,
2020, pp. 111-119. doi: https://doi.org/10.1109/CVPRW50498.2020.00020
.. [ROAD-2022] *Y. Rong, T. Leemann, V. Borisov, G. Kasneci, and E. Kasneci*,
**A Consistent and Efficient Evaluation Strategy for Attribution Methods**, in
Proceedings of the 39th International Conference on Machine Learning, PMLR,
Jun. 2022, pp. 18770–18795. https://proceedings.mlr.press/v162/rong22a.html
......@@ -31,6 +31,7 @@ dependencies = [
"click",
"numpy",
"scipy",
"scikit-image",
"scikit-learn",
"tqdm",
"psutil",
......@@ -39,9 +40,9 @@ dependencies = [
"pillow",
"torch>=1.8",
"torchvision>=0.10",
"lightning>=2.0.3",
"pydantic <2.0,>=1.7.4", # temporary, until issue #31 is fixed
"lightning <2.2.0a0,>=2.1.0",
"tensorboard",
"grad-cam>=1.4.8",
]
[project.urls]
......
......@@ -89,8 +89,10 @@ class RawDataLoader(_BaseRawDataLoader):
basename,
)
# N.B.: NIH CXR-14 images are encoded as color PNGs
# N.B.: some NIH CXR-14 images are encoded as color PNGs with an alpha
# channel. Most are grayscale PNGs.
image = PIL.Image.open(os.path.join(self.datadir, file_path))
image = image.convert("L") # required for some images
tensor = to_tensor(image)
# use the code below to view generated images
......
......@@ -2,12 +2,16 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import collections.abc
import dataclasses
import importlib.resources
import os
import typing
import PIL.Image
import typing_extensions
from torch.utils.data._utils.collate import default_collate_fn_map
from torchvision.transforms.functional import to_tensor
from ptbench.data.datamodule import CachingDataModule
......@@ -22,22 +26,102 @@ CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
database."""
BoundingBoxAnnotation: typing.TypeAlias = tuple[int, int, int, int, int]
"""Location of TB radiological findings (latent or active)
@dataclasses.dataclass
class BoundingBox:
"""Location of radiological findings.
Objects of this type carry bounding-box information of radiological findings on
the original 512x512 pixel images of TBX11k. The radiological findings are
defined as such:
Objects of this type carry bounding-box location of radiological findings
on the original images of TBX11k. The radiological findings are defined as
such:
* 0/1: This labels the sign as latent TB (0), or active TB (1)
* xmin: horizontal position of bounding box upper-left corner, in pixels
* ymin: vertical position of bounding box upper-left corner, in pixels
* width: width of the bounding box, in pixels
* height: height of the bounding box, in pixels
"""
label: int
xmin: int
ymin: int
width: int
height: int
def area(self) -> int:
"""Computes the bounding box area.
Returns
-------
The area in square-pixels.
"""
return self.width * self.height
@property
def xmax(self) -> int:
return self.xmin + self.width - 1
@property
def ymax(self) -> int:
return self.ymin + self.height - 1
def intersection(self, other: typing_extensions.Self) -> int:
"""Computes the area intersection between bounding boxes.
Notice that screen (pixel) geometry is slightly different from
floating-point metrics. Consider a 1D example for the evaluation of the
intersection:
* 2 points: x1 = 1 and x2 = 3, the distance is indeed x2 - x1 = 2
* 2 pixel indices: i1 = 1 and i2 = 3, the segment from pixel i1 to
i2 contains 3 pixels, i.e. l = i2 - i1 + 1
Parameters
----------
other
The other bounding box to check intersections for.
Returns
-------
The area intersection between this and the other bounding-box in
square pixels.
"""
dx = min(self.xmax, other.xmax) - max(self.xmin, other.xmin) + 1
dy = min(self.ymax, other.ymax) - max(self.ymin, other.ymin) + 1
if dx >= 0 and dy >= 0:
return dx * dy
return 0
class BoundingBoxes(collections.abc.Sequence[BoundingBox]):
"""A collection of bounding boxes."""
def __init__(self, t: typing.Sequence[BoundingBox] = []):
self.t = tuple(t)
def __getitem__(self, index):
return self.t[index]
def __len__(self) -> int:
return len(self.t)
# We update the default collate function map to use our custom function as
# explained at:
# https://pytorch.org/docs/stable/data.html#torch.utils.data.default_collate
def _collate_boundingboxes_fn(batch, *, collate_fn_map=None):
"""Custom collate_fn() for pytorch dataloaders that ignores BoundingBoxes
objects."""
return batch
default_collate_fn_map.update({BoundingBoxes: _collate_boundingboxes_fn})
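# Illustrative sketch (not part of the module): the inclusive pixel-geometry
# convention described above means two boxes that share a single pixel have an
# intersection of exactly 1. The box values below are made up.
example_a = BoundingBox(label=1, xmin=0, ymin=0, width=10, height=10)
example_b = BoundingBox(label=0, xmin=9, ymin=9, width=10, height=10)
assert example_a.xmax == 9 and example_a.ymax == 9  # inclusive last pixel index
assert example_a.area() == 100
assert example_a.intersection(example_b) == 1  # the boxes overlap on one pixel only
# Thanks to _collate_boundingboxes_fn above, dataloaders keep BoundingBoxes
# objects as-is (one entry per sample in the batch) instead of trying to stack
# them into tensors.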
* 0/1: This sign is for latent TB (0), or active TB (1)
* xmin: horizontal position of bounding box upper-left corner, in pixels
* ymin: vertical position of bounding box upper-left corner, in pixels
* width: width of the bounding box, in pixels
* height: height of the bounding box, in pixels
"""
DatabaseSample: typing.TypeAlias = (
tuple[str, int] | tuple[str, int, list[BoundingBoxAnnotation]]
tuple[str, int] | tuple[str, int, tuple[tuple[int, int, int, int, int]]]
)
"""Type of objects in our JSON representation for this database.
......@@ -90,7 +174,7 @@ class RawDataLoader(_BaseRawDataLoader):
return tensor, dict(
label=sample[1],
name=sample[0],
radsign_bboxes=self.bbox_annotations(sample),
bounding_boxes=self.bounding_boxes(sample),
)
def label(self, sample: DatabaseSample) -> int:
......@@ -111,10 +195,8 @@ class RawDataLoader(_BaseRawDataLoader):
"""
return sample[1]
def bbox_annotations(
self, sample: DatabaseSample
) -> list[BoundingBoxAnnotation]:
"""Loads a single image sample label from the disk.
def bounding_boxes(self, sample: DatabaseSample) -> BoundingBoxes:
"""Loads image annotated bounding-boxes from the disk.
Parameters
----------
......@@ -129,7 +211,10 @@ class RawDataLoader(_BaseRawDataLoader):
-------
Bounding box annotations, if any are available with the sample.
"""
return sample[2] if len(sample) > 2 else [] # type: ignore
if len(sample) > 2:
return BoundingBoxes([BoundingBox(*k) for k in sample[2]]) # type: ignore
return BoundingBoxes()
def make_split(basename: str) -> DatabaseSplit:
......
......@@ -257,7 +257,7 @@ class ElasticDeformation:
spline_order: int = 1,
mode: str = "nearest",
p: float = 1.0,
parallel: int = -2,
parallel: int = -1,
):
self.alpha: float = alpha
self.sigma: float = sigma
......
......@@ -580,8 +580,10 @@ class ConcatDataModule(lightning.LightningDataModule):
if value < 0:
num_workers = 0
else:
num_workers = value or multiprocessing.cpu_count()
self._dataloader_multiproc["num_workers"] = num_workers
if num_workers > 0 and sys.platform == "darwin":
......@@ -589,6 +591,10 @@ class ConcatDataModule(lightning.LightningDataModule):
"multiprocessing_context"
] = multiprocessing.get_context("spawn")
# keep workers hanging around if we have multiple
if value >= 0:
self._dataloader_multiproc["persistent_workers"] = True
@property
def model_transforms(self) -> list[Transform] | None:
"""Transforms required to fit data into the model.
......@@ -717,7 +723,7 @@ class ConcatDataModule(lightning.LightningDataModule):
if self.cache_samples:
logger.info(
f"Loading dataset:`{name}` into memory (caching)."
f" Trade-off: CPU RAM: more | Disk: less"
f" Trade-off: CPU RAM usage: more | Disk I/O: less"
)
for split, loader in self.splits[name]:
datasets.append(
......@@ -731,7 +737,7 @@ class ConcatDataModule(lightning.LightningDataModule):
else:
logger.info(
f"Loading dataset:`{name}` without caching."
f" Trade-off: CPU RAM: less | Disk: more"
f" Trade-off: CPU RAM usage: less | Disk I/O: more"
)
for split, loader in self.splits[name]:
datasets.append(
......
......@@ -18,8 +18,12 @@ logger = logging.getLogger(__name__)
class LoggingCallback(lightning.pytorch.Callback):
"""Callback to log various training metrics and device information.
It ensures CSVLogger logs training and evaluation metrics on the same line
Note that a CSVLogger only accepts numerical values, and not strings.
Rationale:
1. Losses are logged at the end of every batch, accumulated and handled by
the lightning framework
2. Everything else is done at the end of a training or validation epoch and
mostly concerns runtime metrics such as memory and cpu/gpu utilisation.
Parameters
......@@ -33,13 +37,6 @@ class LoggingCallback(lightning.pytorch.Callback):
def __init__(self, resource_monitor: ResourceMonitor):
super().__init__()
# lists of number of samples/batch and average losses
# - we use this later to compute overall epoch losses
self._training_epoch_loss: tuple[list[int], list[float]] = ([], [])
self._validation_epoch_loss: dict[
int, tuple[list[int], list[float]]
] = {}
# timers
self._start_training_time = 0.0
self._start_training_epoch_time = 0.0
......@@ -101,7 +98,6 @@ class LoggingCallback(lightning.pytorch.Callback):
The lightning module that is being trained
"""
self._start_training_epoch_time = time.time()
self._training_epoch_loss = ([], [])
def on_train_epoch_end(
self,
......@@ -132,17 +128,8 @@ class LoggingCallback(lightning.pytorch.Callback):
# evaluates this training epoch total time, and log it
epoch_time = time.time() - self._start_training_epoch_time
# Compute overall training loss considering batches and sizes
# We disconsider accumulate_grad_batches and assume they were all of
# the same size. This way, the average of averages is the overall
# average.
self._to_log["loss/train"] = torch.mean(
torch.tensor(self._training_epoch_loss[0])
* torch.tensor(self._training_epoch_loss[1])
).item()
self._to_log["epoch-duration-seconds/train"] = epoch_time
self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"]
self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"] # type: ignore
metrics = self._resource_monitor.data
if metrics is not None:
......@@ -155,9 +142,23 @@ class LoggingCallback(lightning.pytorch.Callback):
"missing."
)
# if no validation dataloaders, complete cycle by the end of the
# training epoch, by logging all values to the logger
self.on_cycle_end(trainer, pl_module)
overall_cycle_time = time.time() - self._start_training_epoch_time
self._to_log["cycle-time-seconds/train"] = overall_cycle_time
self._to_log["total-execution-time-seconds"] = (
time.time() - self._start_training_time
)
self._to_log["eta-seconds"] = overall_cycle_time * (
trainer.max_epochs - trainer.current_epoch # type: ignore
)
# the "step" is the tensorboard jargon for "epoch" or "batch",
# depending on how we are logging - in a more general way, it simply
# means the relative time step.
self._to_log["step"] = float(trainer.current_epoch)
# Do not log during sanity check as results are not relevant
if not trainer.sanity_checking:
pl_module.log_dict(self._to_log)
self._to_log = {}
def on_train_batch_end(
self,
......@@ -198,8 +199,14 @@ class LoggingCallback(lightning.pytorch.Callback):
batch_idx
The relative number of the batch
"""
self._training_epoch_loss[0].append(batch[0].shape[0])
self._training_epoch_loss[1].append(outputs["loss"].item())
pl_module.log(
"loss/train",
outputs["loss"].item(),
prog_bar=True,
on_step=False,
on_epoch=True,
batch_size=batch[0].shape[0],
)
def on_validation_epoch_start(
self,
......@@ -229,7 +236,6 @@ class LoggingCallback(lightning.pytorch.Callback):
The lightning module that is being trained
"""
self._start_validation_epoch_time = time.time()
self._validation_epoch_loss = {}
def on_validation_epoch_end(
self,
......@@ -271,20 +277,12 @@ class LoggingCallback(lightning.pytorch.Callback):
"missing."
)
# Compute overall validation losses considering batches and sizes
# We disconsider accumulate_grad_batches and assume they were all
# of the same size. This way, the average of averages is the
# overall average.
for key in sorted(self._validation_epoch_loss.keys()):
if key == 0:
name = "loss/validation"
else:
name = f"loss/validation-{key}"
self._to_log[name] = torch.mean(
torch.tensor(self._validation_epoch_loss[key][0])
* torch.tensor(self._validation_epoch_loss[key][1])
).item()
self._to_log["step"] = float(trainer.current_epoch)
# Do not log during sanity check as results are not relevant
if not trainer.sanity_checking:
pl_module.log_dict(self._to_log)
self._to_log = {}
def on_validation_batch_end(
self,
......@@ -330,50 +328,17 @@ class LoggingCallback(lightning.pytorch.Callback):
Index of the dataloader used during validation. Use this to figure
out which dataset was used for this validation epoch.
"""
size, value = self._validation_epoch_loss.setdefault(
dataloader_idx, ([], [])
)
size.append(batch[0].shape[0])
value.append(outputs.item())
def on_cycle_end(
self,
trainer: lightning.pytorch.Trainer,
pl_module: lightning.pytorch.LightningModule,
) -> None:
"""Called when the training/validation cycle has ended.
This function will log all relevant values to the various loggers. It
is supposed to be called by the end of the training cycle (consisting
of a training and validation step).
Parameters
----------
trainer
The Lightning trainer object
pl_module
The lightning module that is being trained
"""
# collect some final time for the whole training cycle
# Note: logging should happen at on_validation_end(), but
# apparently you can't log from there
overall_cycle_time = time.time() - self._start_training_epoch_time
self._to_log["cycle-time-seconds/train"] = overall_cycle_time
self._to_log["total-execution-time-seconds"] = (
time.time() - self._start_training_time
)
self._to_log["eta-seconds"] = overall_cycle_time * (
trainer.max_epochs - trainer.current_epoch # type: ignore
if dataloader_idx == 0:
key = "loss/validation"
else:
key = f"loss/validation-{dataloader_idx}"
pl_module.log(
key,
outputs.item(),
prog_bar=False,
on_step=False,
on_epoch=True,
batch_size=batch[0].shape[0],
)
# Do not log during sanity check as results are not relevant
if not trainer.sanity_checking:
for k in sorted(self._to_log.keys()):
pl_module.log_dict(
{k: self._to_log[k], "step": float(trainer.current_epoch)}
)
self._to_log = {}
......@@ -318,6 +318,8 @@ def aggregate_roc(
A dictionary mapping split names to ROC curve data produced by
:py:func:`sklearn.metrics.roc_curve`.
title
The title of the plot.
Returns
-------
......@@ -471,6 +473,8 @@ def aggregate_pr(
data
A dictionary mapping split names to precision-recall curve data produced
by :py:func:`sklearn.metrics.precision_recall_curve`.
title
The title of the plot.
Returns
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import functools
import logging
import multiprocessing
import typing
import lightning.pytorch
import numpy as np
import torch
import tqdm
from pytorch_grad_cam.metrics.road import (
ROADLeastRelevantFirstAverage,
ROADMostRelevantFirstAverage,
)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from ...data.typing import Sample
from ...models.typing import SaliencyMapAlgorithm
from ..device import DeviceManager
logger = logging.getLogger(__name__)
class SigmoidClassifierOutputTarget(torch.nn.Module):
def __init__(self, category):
self.category = category
def __call__(self, model_output):
sigmoid_output = torch.sigmoid(model_output)
if len(sigmoid_output.shape) == 1:
return sigmoid_output[self.category]
return sigmoid_output[:, self.category]
def _calculate_road_scores(
model: lightning.pytorch.LightningModule,
images: torch.Tensor,
output_num: int,
saliency_map_callable: typing.Callable,
percentiles: typing.Sequence[int],
) -> tuple[float, float, float]:
"""Calculates average ROAD scores for different removal percentiles.
This function calculates ROAD scores by averaging the scores obtained for
different removal percentiles, for a single input image, a given
visualization method, and a target class.
Parameters
----------
model
Neural network model (e.g. pasa).
images
A batch of input images to use when evaluating the ROAD scores. Currently,
we only support batches with a single image.
output_num
Target output neuron to take into consideration when evaluating the
saliency maps and calculating ROAD scores
saliency_map_callable
A callable saliency-map generator from grad-cam
percentiles
A sequence of integer percentile values (percentages, i.e. fraction x 100)
indicating the proportion of pixels to perturb in the original image to
calculate both MoRF and LeRF scores.
Returns
-------
A 3-tuple containing floating point numbers representing the
most-relevant-first average score (``morf``), least-relevant-first
average score (``lerf``) and the combined value (``(lerf-morf)/2``).
"""
saliency_map = saliency_map_callable(
input_tensor=images, targets=[ClassifierOutputTarget(output_num)]
)
cam_metric_ROADMoRF_avg = ROADMostRelevantFirstAverage(
percentiles=percentiles
)
cam_metric_ROADLeRF_avg = ROADLeastRelevantFirstAverage(
percentiles=percentiles
)
# Calculate ROAD scores for all percentiles and average - this is NOT the
# current processing bottleneck. If you want to optimise anything, look at
# the evaluation of the perturbation using scipy.sparse at the
# NoisyLinearImputer, part of the grad-cam package (submodule
# ``metrics.road``).
metric_target = [SigmoidClassifierOutputTarget(output_num)]
MoRF_scores = cam_metric_ROADMoRF_avg(
input_tensor=images,
cams=saliency_map,
model=model,
targets=metric_target,
)
LeRF_scores = cam_metric_ROADLeRF_avg(
input_tensor=images,
cams=saliency_map,
model=model,
targets=metric_target,
)
return (
float(MoRF_scores.item()),
float(LeRF_scores.item()),
float(LeRF_scores.item() - MoRF_scores.item()) / 2.0,
)
def _process_sample(
sample: Sample,
model: lightning.pytorch.LightningModule,
device: torch.device,
saliency_map_callable: typing.Callable,
target_class: typing.Literal["highest", "all"],
positive_only: bool,
percentiles: typing.Sequence[int],
) -> list:
"""Helper function to :py:func:`run` to be used in multiprocessing
contexts.
Parameters
----------
model
Neural network model (e.g. pasa).
device
The device to process samples on.
saliency_map_callable
A callable saliency-map generator from grad-cam
target_class
Class to target for saliency estimation. Can be set to either
"all" or "highest". "highest" is the default.
positive_only
If set, and the model chosen has a single output (binary), then
saliency maps will only be generated for samples of the positive class.
percentiles
A sequence of integer percentile values (percentages, i.e. fraction x 100)
indicating the proportion of pixels to perturb in the original image to
calculate both MoRF and LeRF scores.
"""
name: str = sample[1]["name"][0]
label: int = int(sample[1]["label"].item())
image = sample[0].to(device=device, non_blocking=torch.cuda.is_available())
# in binary classification systems, negative labels may be skipped
if positive_only and (model.num_classes == 1) and (label == 0):
return [name, label]
# chooses target outputs to generate saliency maps for
if model.num_classes > 1: # type: ignore
if target_class == "all":
# test all outputs
for output_num in range(model.num_classes): # type: ignore
results = _calculate_road_scores(
model,
image,
output_num,
saliency_map_callable,
percentiles,
)
return [name, label, output_num, *results]
else:
# we will figure out the output with the highest value and
# evaluate the saliency mapping technique over it.
outputs = saliency_map_callable.activations_and_grads(image) # type: ignore
output_nums = np.argmax(outputs.cpu().data.numpy(), axis=-1)
assert len(output_nums) == 1
results = _calculate_road_scores(
model,
image,
output_nums[0],
saliency_map_callable,
percentiles,
)
return [name, label, output_nums[0], *results]
# default route for binary classification
results = _calculate_road_scores(
model,
image,
0,
saliency_map_callable,
percentiles,
)
return [name, label, 0, *results]
def run(
model: lightning.pytorch.LightningModule,
datamodule: lightning.pytorch.LightningDataModule,
device_manager: DeviceManager,
saliency_map_algorithm: SaliencyMapAlgorithm,
target_class: typing.Literal["highest", "all"],
positive_only: bool,
percentiles: typing.Sequence[int],
parallel: int,
) -> dict[str, list[typing.Any]]:
"""Evaluates ROAD scores for all samples in a datamodule.
The ROAD algorithm was first described in [ROAD-2022]_. It estimates
explainability (in the completeness sense) of saliency maps by substituting
relevant pixels in the input image with a local average, re-running
prediction on the altered image, and measuring changes in the output
classification score when said perturbations are in place. By substituting
most or least relevant pixels with surrounding averages, the ROAD algorithm
estimates the importance of such elements in the produced saliency map. As
of 2023, this measurement technique is considered to be one of the
state-of-the-art metrics of explainability.
This function returns a dictionary containing most-relevant-first (remove a
percentile of the most relevant pixels), least-relevant-first (remove a
percentile of the least relevant pixels), and combined ROAD evaluations per
sample for a particular saliency mapping algorithm.
Parameters
----------
model
Neural network model (e.g. pasa).
datamodule
The lightning datamodule to iterate on.
device_manager
An internal device representation, to be used for training and
validation. This representation can be converted into a pytorch device
or a torch lightning accelerator setup.
saliency_map_algorithm
The algorithm for saliency map estimation to use.
target_class
(Use only with multi-label models) Which class to target for CAM
calculation. Can be set to either "all" or "highest". "highest" is the
default, which means only saliency maps for the class with the highest
activation will be generated.
positive_only
If set, saliency maps will only be generated for positive samples (ie.
label == 1 in a binary classification task). This option is ignored on
a multi-class output model.
percentiles
A sequence of integer percentile values (percentages, i.e. fraction x 100)
indicating the proportion of pixels to perturb in the original image to
calculate both MoRF and LeRF scores.
parallel
Use multiprocessing for data processing: if set to -1, disables
multiprocessing. Set to 0 to enable as many data processing instances
as processing cores as available in the system. Set to >= 1 to enable
that many multiprocessing instances for data processing.
Returns
-------
A dictionary where keys are dataset names in the provided datamodule,
and values are lists containing sample information alongside metrics
calculated:
* Sample name
* Sample target class
* The model output number used for the ROAD analysis (0, for binary
classifiers as there is typically only one output).
* ``morf``: ROAD most-relevant-first average of percentiles 20, 40, 60 and
80 (a.k.a. AOPC-MoRF).
* ``lerf``: ROAD least-relevant-first average of percentiles 20, 40, 60 and
80 (a.k.a. AOPC-LeRF).
* combined: Average ROAD combined score by evaluating ``(lerf-morf)/2``
(a.k.a. AOPC-Combined).
"""
from ...models.densenet import Densenet
from ...models.pasa import Pasa
from .generator import _create_saliency_map_callable
if isinstance(model, Pasa):
if saliency_map_algorithm == "fullgrad":
raise ValueError(
"Fullgrad saliency map algorithm is not supported for the "
"Pasa model."
)
target_layers = [model.fc14] # Last non-1x1 Conv2d layer
elif isinstance(model, Densenet):
target_layers = [
model.model_ft.features.denseblock4.denselayer16.conv2, # type: ignore
]
else:
raise TypeError(f"Model of type `{type(model)}` is not yet supported.")
use_cuda = device_manager.device_type == "cuda"
if device_manager.device_type in ("cuda", "mps") and (
parallel == 0 or parallel > 1
):
raise RuntimeError(
f"The number of multiprocessing instances is set to {parallel} and "
f"you asked to use a GPU (device = `{device_manager.device_type}`"
f"). The currently implementation can only handle a single GPU. "
f"Either disable GPU utilisation or set the number of "
f"multiprocessing instances to one, or disable multiprocessing "
"entirely (ie. set it to -1)."
)
# prepares model for evaluation, cast to target device
device = device_manager.torch_device()
model = model.to(device)
model.eval()
saliency_map_callable = _create_saliency_map_callable(
saliency_map_algorithm,
model,
target_layers, # type: ignore
use_cuda,
)
retval: dict[str, list[typing.Any]] = {}
# our worker function
_process = functools.partial(
_process_sample,
model=model,
device=device,
saliency_map_callable=saliency_map_callable,
target_class=target_class,
positive_only=positive_only,
percentiles=percentiles,
)
for k, v in datamodule.predict_dataloader().items():
retval[k] = []
if parallel < 0:
logger.info(
f"Computing ROAD scores for dataset `{k}` in the current "
f"process context..."
)
for sample in tqdm.tqdm(
v, desc="samples", leave=False, disable=None
):
retval[k].append(_process(sample))
else:
instances = parallel or multiprocessing.cpu_count()
logger.info(
f"Computing ROAD scores for dataset `{k}` using {instances} "
f"processes..."
)
with multiprocessing.Pool(instances) as p:
retval[k] = list(tqdm.tqdm(p.imap(_process, v), total=len(v)))
return retval
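# Illustrative usage sketch (not part of the module): the model, datamodule and
# device manager below are assumptions; only the signature of run() and the
# layout of its return value are taken from the code above.
scores = run(
    model=model,                     # e.g. a trained Pasa instance
    datamodule=datamodule,           # datamodule exposing predict_dataloader()
    device_manager=device_manager,   # a DeviceManager for "cpu" or "cuda"
    saliency_map_algorithm="gradcam",
    target_class="highest",
    positive_only=True,
    percentiles=[20, 40, 60, 80],
    parallel=-1,                     # -1 disables multiprocessing
)
for dataset_name, rows in scores.items():
    for row in rows:
        if len(row) == 2:  # skipped sample: only [name, label] is recorded
            continue
        name, label, output_num, morf, lerf, combined = row
        # combined == (lerf - morf) / 2, a.k.a. AOPC-Combined
        print(f"{dataset_name}/{name}: AOPC-Combined = {combined:.3f}")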
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import typing
import matplotlib.figure
import numpy
import numpy.typing
import tabulate
from ...models.typing import SaliencyMapAlgorithm
def _reconcile_metrics(
completeness: list,
interpretability: list,
) -> list[tuple[str, int, float, float, float]]:
"""Summarizes samples into a new table containing most important scores.
It returns a list containing a table with completeness and ROAD scores per
sample, for the selected dataset. Only samples for which both completeness
and interpretability scores are available are returned in the reconciled list.
Parameters
----------
completeness
A dictionary containing various tables with the sample name and
completeness (ROAD) scores.
interpretability
A dictionary containing various tables with the sample name and
interpretability (Prop. Energy) scores.
Returns
-------
A list containing a table with the sample name, target label,
completeness score (Average ROAD across different ablation thresholds),
interpretability score (Proportional Energy), and the ROAD-Weighted
Proportional Energy score. The ROAD-Weighted Prop. Energy score is
defined as:
.. math::
\\text{ROAD-WeightedPropEng} = \\max(0, \\text{AvgROAD}) \\cdot
\\text{ProportionalEnergy}
"""
retval: list[tuple[str, int, float, float, float]] = []
for compl_info, interp_info in zip(completeness, interpretability):
# ensure matching sample name and label
assert compl_info[0] == interp_info[0]
assert compl_info[1] == interp_info[1]
if len(compl_info) == len(interp_info) == 2:
# there is no data.
continue
aopc_combined = compl_info[5]
prop_energy = interp_info[2]
road_weighted_prop_energy = max(0, aopc_combined) * prop_energy
retval.append(
(
compl_info[0],
compl_info[1],
aopc_combined,
prop_energy,
road_weighted_prop_energy,
)
)
return retval
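# Numeric illustration (made-up values, not part of the module): a sample with
# AOPC-Combined = 0.30 and Proportional Energy = 0.60 gets a ROAD-weighted
# score of max(0, 0.30) * 0.60 = 0.18; a negative AOPC-Combined is clamped to
# zero, so such a sample always contributes a ROAD-weighted score of 0.
example_aopc_combined, example_prop_energy = 0.30, 0.60
example_road_weighted = max(0, example_aopc_combined) * example_prop_energy  # 0.18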
def _make_histogram(
name: str,
values: numpy.typing.NDArray,
xlim: tuple[float, float] | None = None,
title: None | str = None,
) -> matplotlib.figure.Figure:
"""Builds an histogram of values.
Parameters
----------
name
Name of the variable to be histogrammed (will appear in the figure)
values
Values to be histogrammed
xlim
A tuple representing the X-axis maximum and minimum to plot. If not
set, then use the bin boundaries.
title
A title to set on the histogram
Returns
-------
A matplotlib figure containing the histogram.
"""
from matplotlib import pyplot
fig, ax = pyplot.subplots(1)
ax = typing.cast(matplotlib.figure.Axes, ax)
ax.set_xlabel(name)
ax.set_ylabel("Frequency")
if title is not None:
ax.set_title(title)
else:
ax.set_title(f"{name} Frequency Histogram")
n, bins, _ = ax.hist(values, bins="auto", density=True, alpha=0.7)
if xlim is not None:
ax.spines.bottom.set_bounds(*xlim)
else:
ax.spines.bottom.set_bounds(bins[0], bins[-1])
ax.spines.left.set_bounds(0, n.max())
ax.spines.right.set_visible(False)
ax.spines.top.set_visible(False)
ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.3)
# draw median and quartiles
quartile = numpy.percentile(values, [25, 50, 75])
ax.axvline(
quartile[0], color="green", linestyle="--", label="Q1", alpha=0.5
)
ax.axvline(quartile[1], color="red", label="median", alpha=0.5)
ax.axvline(
quartile[2], color="green", linestyle="--", label="Q3", alpha=0.5
)
return fig # type: ignore
def summary_table(
summary: dict[SaliencyMapAlgorithm, dict[str, typing.Any]], fmt: str
) -> str:
"""Tabulates various summaries into one table.
Parameters
----------
summary
A dictionary mapping saliency algorithm names to the results of
:py:func:`run`.
fmt
One of the formats supported by `python-tabulate
<https://pypi.org/project/tabulate/>`_.
Returns
-------
A string containing the tabulated information.
"""
headers = [
"Algorithm",
"AOPC-Combined",
"Prop. Energy",
"ROAD-Normalised",
]
table = [
[
k,
v["aopc-combined"]["quartiles"][50],
v["proportional-energy"]["quartiles"][50],
v["road-normalised-proportional-energy-average"],
]
for k, v in summary.items()
]
return tabulate.tabulate(table, headers, tablefmt=fmt, floatfmt=".3f")
def _extract_statistics(
algo: SaliencyMapAlgorithm,
data: list[tuple[str, int, float, float, float]],
name: str,
index: int,
dataset: str,
xlim: tuple[float, float] | None = None,
) -> dict[str, typing.Any]:
"""Extracts all meaningful statistics from a reconciled statistics set.
Parameters
----------
algo
The algorithm for saliency map estimation that is being analysed.
data
A list of tuples each containing a sample name, target, and values
produced by completeness and interpretability analysis as returned by
:py:func:`_reconcile_metrics`.
name
The name of the variable being analysed
index
Which of the indexes on the tuples containing in ``data`` that should
be extracted.
dataset
The name of the dataset being analysed
xlim
Limits for histogram plotting
Returns
-------
A dictionary containing the following elements:
* ``values``: A list of values corresponding to the index on the data
* ``mean``: The mean of the value list
* ``stdev``: The standard deviation of the value list
* ``quartiles``: The 25%, 50% (median), and 75% quartile of values
* ``plot``: A histogram of values
* ``decreasing_scores``: A list of sample names and labels in
decreasing value.
"""
val = numpy.array([k[index] for k in data])
return dict(
values=val,
mean=val.mean(),
stdev=val.std(ddof=1), # unbiased estimator
quartiles={
25: numpy.percentile(val, 25), # type: ignore
50: numpy.median(val), # type: ignore
75: numpy.percentile(val, 75), # type: ignore
},
plot=_make_histogram(
name,
val,
xlim=xlim,
title=f"{name} Frequency Histogram ({algo} @ {dataset})",
),
decreasing_scores=[
(k[0], k[index])
for k in sorted(data, key=lambda x: x[index], reverse=True)
],
)
def run(
saliency_map_algorithm: SaliencyMapAlgorithm,
completeness: dict[str, list],
interpretability: dict[str, list],
) -> dict[str, typing.Any]:
"""Evaluates multiple saliency map algorithms and produces summarized
results.
Parameters
----------
saliency_map_algorithm
The algorithm for saliency map estimation that is being analysed.
completeness
A dictionary mapping dataset names to tables with the sample name and
completeness (among which Average ROAD) scores.
interpretability
A dictionary mapping dataset names to tables with the sample name and
interpretability (among which Prop. Energy) scores.
Returns
-------
A dictionary with most important statistical values for the main
completeness (AOPC-Combined), interpretability (Prop. Energy), and a
combination of both (ROAD-Weighted Prop. Energy) scores.
"""
retval: dict = {}
for dataset, compl_data in completeness.items():
reconciled = _reconcile_metrics(compl_data, interpretability[dataset])
d = {}
d["aopc-combined"] = _extract_statistics(
algo=saliency_map_algorithm,
data=reconciled,
name="AOPC-Combined",
index=2,
dataset=dataset,
)
d["proportional-energy"] = _extract_statistics(
algo=saliency_map_algorithm,
data=reconciled,
name="Prop.Energy",
index=3,
dataset=dataset,
xlim=(0, 1),
)
d["road-weighted-proportional-energy"] = _extract_statistics(
algo=saliency_map_algorithm,
data=reconciled,
name="ROAD-weighted-Prop.Energy",
index=4,
dataset=dataset,
)
d["road-normalised-proportional-energy-average"] = sum(
retval["road-weighted-proportional-energy"]["val"]
) / sum([max(0, k) for k in retval["aopc-combined"]["val"]])
retval[dataset] = d
return retval
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import pathlib
import typing
import lightning.pytorch
import numpy
import torch
import torch.nn
import tqdm
from ...models.typing import SaliencyMapAlgorithm
from ..device import DeviceManager
logger = logging.getLogger(__name__)
def _create_saliency_map_callable(
algo_type: SaliencyMapAlgorithm,
model: torch.nn.Module,
target_layers: list[torch.nn.Module] | None,
use_cuda: bool,
):
"""Creates a class activation map (CAM) instance for a given model."""
import pytorch_grad_cam
match algo_type:
case "gradcam":
return pytorch_grad_cam.GradCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "scorecam":
return pytorch_grad_cam.ScoreCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "fullgrad":
return pytorch_grad_cam.FullGrad(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "randomcam":
return pytorch_grad_cam.RandomCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "hirescam":
return pytorch_grad_cam.HiResCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "gradcamelementwise":
return pytorch_grad_cam.GradCAMElementWise(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "gradcam++", "gradcamplusplus":
return pytorch_grad_cam.GradCAMPlusPlus(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "xgradcam":
return pytorch_grad_cam.XGradCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "ablationcam":
assert (
target_layers is not None
), "AblationCAM cannot have target_layers=None"
return pytorch_grad_cam.AblationCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "eigencam":
return pytorch_grad_cam.EigenCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "eigengradcam":
return pytorch_grad_cam.EigenGradCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case "layercam":
return pytorch_grad_cam.LayerCAM(
model=model, target_layers=target_layers, use_cuda=use_cuda
)
case _:
raise ValueError(
f"Saliency map algorithm `{algo_type}` is not currently "
f"supported."
)
def _save_saliency_map(
output_folder: pathlib.Path, name: str, saliency_map: torch.Tensor
) -> None:
"""Helper function to save a saliency map to disk.
Parameters
----------
output_folder
Directory in which the resulting saliency maps will be saved.
name
Name of the saved file.
saliency_map
A real-valued saliency-map that conveys regions used for
classification in the original sample.
"""
n = pathlib.Path(name)
(output_folder / n.parent).mkdir(parents=True, exist_ok=True)
numpy.save(output_folder / n.with_suffix(".npy"), saliency_map[0])
def run(
model: lightning.pytorch.LightningModule,
datamodule: lightning.pytorch.LightningDataModule,
device_manager: DeviceManager,
saliency_map_algorithm: SaliencyMapAlgorithm,
target_class: typing.Literal["highest", "all"],
positive_only: bool,
output_folder: pathlib.Path,
) -> None:
"""Applies saliency mapping techniques on input CXR, outputs pickled
saliency maps directly to disk.
Parameters
----------
model
Neural network model (e.g. pasa).
datamodule
The lightning datamodule to iterate on.
device_manager
An internal device representation, to be used for training and
validation. This representation can be converted into a pytorch device
or a torch lightning accelerator setup.
saliency_map_algorithm
The algorithm to use for saliency map estimation.
target_class
(Use only with multi-label models) Which class to target for CAM
calculation. Can be set to either "all" or "highest". "highest" is the
default, which means only saliency maps for the class with the highest
activation will be generated.
positive_only
If set, saliency maps will only be generated for positive samples (ie.
label == 1 in a binary classification task). This option is ignored on
a multi-class output model.
output_folder
Where to save all the saliency maps (this path should exist before
this function is called)
"""
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from ...models.densenet import Densenet
from ...models.pasa import Pasa
if isinstance(model, Pasa):
if saliency_map_algorithm == "fullgrad":
raise ValueError(
"Fullgrad saliency map algorithm is not supported for the "
"Pasa model."
)
target_layers = [model.fc14] # Last non-1x1 Conv2d layer
elif isinstance(model, Densenet):
target_layers = [
model.model_ft.features.denseblock4.denselayer16.conv2, # type: ignore
]
else:
raise TypeError(f"Model of type `{type(model)}` is not yet supported.")
use_cuda = device_manager.device_type == "cuda"
# prepares model for evaluation, cast to target device
device = device_manager.torch_device()
model = model.to(device)
model.eval()
saliency_map_callable = _create_saliency_map_callable(
saliency_map_algorithm,
model,
target_layers, # type: ignore
use_cuda,
)
for k, v in datamodule.predict_dataloader().items():
logger.info(
f"Generating saliency maps for dataset `{k}` via `{saliency_map_algorithm}`..."
)
for sample in tqdm.tqdm(v, desc="samples", leave=False, disable=None):
name = sample[1]["name"][0]
label = sample[1]["label"].item()
image = sample[0].to(
device=device, non_blocking=torch.cuda.is_available()
)
# in binary classification systems, negative labels may be skipped
if positive_only and (model.num_classes == 1) and (label == 0):
continue
# chooses target outputs to generate saliency maps for
if model.num_classes > 1:
if target_class == "all":
# just blindly generate saliency maps for all outputs
# - make one directory for every target output and lay
# images there like in the original dataset.
for output_num in range(model.num_classes):
use_folder = output_folder / str(output_num)
saliency_map = saliency_map_callable(
input_tensor=image,
targets=[ClassifierOutputTarget(output_num)], # type: ignore
)
_save_saliency_map(use_folder, name, saliency_map) # type: ignore
else:
# pytorch-grad-cam will figure out the output with the
# highest value and produce a saliency map for it - we
# will save it to disk.
use_folder = output_folder / "highest-output"
saliency_map = saliency_map_callable(
input_tensor=image,
# setting `targets=None` will set target to the
# maximum output index using
# ClassifierOutputTarget(max_output_index)
targets=None, # type: ignore
)
_save_saliency_map(use_folder, name, saliency_map) # type: ignore
else:
# binary classification model with a single output - just
# lay all cams uniformly like the original dataset
saliency_map = saliency_map_callable(
input_tensor=image,
targets=[
ClassifierOutputTarget(0), # type: ignore
],
)
_save_saliency_map(output_folder, name, saliency_map) # type: ignore
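# Illustrative sketch (not part of the module): after run() completes, each
# processed sample yields one ``.npy`` file under ``output_folder`` mirroring
# the sample name (with per-output subfolders for multi-class models, as
# implemented above). Loading one back could look like this; ``output_folder``
# and ``name`` are assumptions here:
saliency = numpy.load(output_folder / pathlib.Path(name).with_suffix(".npy"))
# ``saliency`` is a real-valued 2D map matching the spatial size of the input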
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import pathlib
import typing
import lightning.pytorch
import numpy
import numpy.typing
import skimage.measure
import torch
import torchvision.ops
from tqdm import tqdm
from ...config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes
logger = logging.getLogger(__name__)
def _ordered_connected_components(
saliency_map: typing.Sequence[typing.Sequence[float]]
| numpy.typing.NDArray[numpy.double],
threshold: float,
) -> list[numpy.typing.NDArray[numpy.bool_]]:
"""Calculates the largest connected components available on a saliency map
and returns those as individual masks.
This implementation is based on [SCORECAM-2020]_:
1. Thresholding: The pixel values above ``threshold``% of max value are
kept in the original saliency map. Everything else is set to zero. The
value proposed on [SCORECAM-2020]_ is 0.2. Use this value if unsure.
2. The thresholded saliency map is transformed into a boolean array (ones
are attributed to all elements above the threshold).
3. We call :py:func:`skimage.measure.label` to evaluate all connected
components and label those with distinct integers.
4. We histogram the labels and return one binary mask for each label,
sorted by decreasing size.
Parameters
----------
saliency_map
Input saliency map from which connected components will be calculated.
threshold
Relative threshold to be used to zero parts of the original saliency
map. A value of 0.2 will zero all values in the saliency map that are
below 20% of the maximum value observed in the said map.
Returns
-------
A list of boolean masks, one for each connected component, ordered by
decreasing size. This list may be empty if the input ``saliency_map``
is all zeroes.
"""
# thresholds like [SCORECAM-2020]_
saliency_array = numpy.array(saliency_map)
thresholded_mask = (
saliency_array >= (threshold * saliency_array.max())
).astype(numpy.uint8)
# avoids an all zeroes mask being processed
if not numpy.any(thresholded_mask):
return []
labelled, n = skimage.measure.label(thresholded_mask, return_num=True) # type: ignore
retval = [labelled == k for k in range(1, n + 1)]
return sorted(retval, key=lambda x: x.sum(), reverse=True)
def _extract_bounding_box(
mask: numpy.typing.NDArray[numpy.bool_],
) -> BoundingBox:
"""Defines a bounding box surrounding a connected component mask.
Parameters
----------
mask
The connected component mask from which to extract the bounding box.
Returns
-------
A bounding box.
"""
x, y, x2, y2 = torchvision.ops.masks_to_boxes(torch.tensor(mask)[None, :])[
0
]
return BoundingBox(-1, int(x), int(y), int(x2 - x + 1), int(y2 - y + 1))
def _compute_max_iou_and_ioda(
detected_box: BoundingBox,
gt_bboxes: BoundingBoxes,
) -> tuple[float, float]:
"""Will calculate how much of detected area lies in ground truth boxes.
If there are multiple gt boxes, the detected area will be calculated
for each gt box separately and the gt box with the highest
intersecting part will be used for the calculation.
Parameters
----------
detected_box
BoundingBox of the detected area.
gt_bboxes
Ground-truth bounding boxes in the format ``(x, y, width,
height)``.
Returns
-------
The max iou and ioda values.
"""
detected_area = detected_box.area()
if detected_area == 0:
return 0.0, 0.0
max_intersection = 0
max_gt_area = 0
for bbox in gt_bboxes:
intersection = bbox.intersection(detected_box)
if intersection > max_intersection:
max_intersection = intersection
max_gt_area = bbox.area()
if max_gt_area == 0 and max_intersection == 0:
# This case means no intersection was found, even though there are gt boxes
iou, ioda = 0.0, 0.0
else:
iou = max_intersection / (
detected_area + max_gt_area - max_intersection
)
ioda = max_intersection / detected_area
return iou, ioda
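# Worked numeric example (made-up boxes, not part of the module): a 20x20
# detected box at (10, 10) against a single 20x20 ground-truth box at (20, 20).
example_detected = BoundingBox(label=-1, xmin=10, ymin=10, width=20, height=20)
example_gt = BoundingBoxes([BoundingBox(1, 20, 20, 20, 20)])
example_iou, example_ioda = _compute_max_iou_and_ioda(example_detected, example_gt)
# the overlap spans pixel columns/rows 20..29 (inclusive), i.e. 10 x 10 = 100 px²:
#   iou  = 100 / (400 + 400 - 100) ≈ 0.143
#   ioda = 100 / 400               = 0.25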
def _get_largest_bounding_boxes(
saliency_map: typing.Sequence[typing.Sequence[float]]
| numpy.typing.NDArray[numpy.double],
n: int,
threshold: float = 0.2,
) -> list[BoundingBox]:
"""Returns the N largest connected components as bounding boxes in a
saliency map.
The returned values depend on the applied ``threshold``, as well as on the
saliency map itself. The number of objects found is also affected by those
parameters.
Parameters
----------
saliency_map
Input saliency map from which connected components will be calculated.
n
The number of connected components to search for in the saliency map.
Connected components are then translated to bounding-box notation.
threshold
Relative threshold to be used to zero parts of the original saliency
map. A value of 0.2 will zero all values in the saliency map that are
below 20% of the maximum value observed in the said map.
Returns
-------
The N largest connected components as bounding boxes in a saliency map.
"""
retval: list[BoundingBox] = []
masks = _ordered_connected_components(saliency_map, threshold)
if masks:
retval += [_extract_bounding_box(k) for k in masks[:n]]
return retval
def _compute_simultaneous_iou_and_ioda(
detected_box: BoundingBox,
gt_bboxes: BoundingBoxes,
) -> tuple[float, float]:
"""Will calculate how much of detected area lies between ground truth
boxes.
This means that if there are multiple gt boxes, the detected area
will be compared to them simultaneously (and not to each gt box
separately).
Parameters
----------
detected_box
BoundingBox of the detected area.
gt_bboxes
Collection of ground-truth bounding boxes.
Returns
-------
The iou and ioda for the provided boxes.
"""
detected_area = detected_box.area()
if detected_area == 0:
return 0, 0
intersection = sum([k.intersection(detected_box) for k in gt_bboxes])
total_gt_area = sum([k.area() for k in gt_bboxes])
iou = intersection / (detected_area + total_gt_area - intersection)
ioda = intersection / detected_area
return float(iou), float(ioda)
def _compute_avg_saliency_focus(
saliency_map: numpy.typing.NDArray[numpy.double],
gt_mask: numpy.typing.NDArray[numpy.bool_],
) -> float:
"""Integrates the saliency map over the ground-truth boxes and normalizes
by total bounding-box area.
This function will integrate (sum) the value of the saliency map over the
ground-truth bounding boxes and normalize it by the total area covered by
all ground-truth bounding boxes.
Parameters
----------
saliency_map
A real-valued saliency-map that conveys regions used for
classification in the original sample.
gt_mask
Ground-truth mask containing the bounding boxes of the ground-truth
drawn as ``True`` values.
Returns
-------
A single floating-point number representing the Average saliency focus.
"""
area = gt_mask.sum()
if area == 0:
return 0.0
return numpy.sum(saliency_map * gt_mask) / area
def _compute_proportional_energy(
saliency_map: numpy.typing.NDArray[numpy.double],
gt_mask: numpy.typing.NDArray[numpy.bool_],
) -> float:
"""Calculates how much activation lies within the ground truth boxes
compared to the total sum of the activations (integral).
Parameters
----------
saliency_map
A real-valued saliency-map that conveys regions used for
classification in the original sample.
gt_mask
Ground-truth mask containing the bounding boxes of the ground-truth
drawn as ``True`` values.
Returns
-------
A single floating-point number representing the proportional energy.
"""
denominator = numpy.sum(saliency_map)
if denominator == 0.0:
return 0.0
return float(numpy.sum(saliency_map * gt_mask) / denominator) # type: ignore
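# Toy illustration (made-up numbers, not part of the module): a 4x4 saliency
# map with the annotated region covering the top-left 2x2 corner.
example_saliency = numpy.array(
    [
        [0.8, 0.6, 0.0, 0.0],
        [0.4, 0.2, 0.0, 0.0],
        [0.0, 0.0, 0.1, 0.1],
        [0.0, 0.0, 0.1, 0.1],
    ]
)
example_gt_mask = numpy.zeros_like(example_saliency, dtype=numpy.bool_)
example_gt_mask[:2, :2] = True
# proportional energy: activation inside the boxes over total activation,
#   (0.8 + 0.6 + 0.4 + 0.2) / (2.0 + 0.4) = 2.0 / 2.4 ≈ 0.833
assert abs(_compute_proportional_energy(example_saliency, example_gt_mask) - 2.0 / 2.4) < 1e-6
# average saliency focus: activation inside the boxes over the box area,
#   2.0 / 4 = 0.5
assert abs(_compute_avg_saliency_focus(example_saliency, example_gt_mask) - 0.5) < 1e-6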
def _compute_binary_mask(
gt_bboxes: BoundingBoxes,
saliency_map: numpy.typing.NDArray[numpy.double],
) -> numpy.typing.NDArray[numpy.bool_]:
"""Computes a binary mask for the saliency map using BoundingBoxes.
The binary_mask will be ON/True where the gt boxes are located.
Parameters
----------
gt_bboxes
Ground-truth bounding boxes in the format ``(x, y, width,
height)``.
saliency_map
A real-valued saliency-map that conveys regions used for
classification in the original sample.
Returns
-------
A numpy array of the same size as saliency_map with
the value False everywhere except at the positions inside
the bounding boxes, which will be True.
"""
binary_mask = numpy.zeros_like(saliency_map, dtype=numpy.bool_)
for bbox in gt_bboxes:
binary_mask[
bbox.ymin : bbox.ymin + bbox.height,
bbox.xmin : bbox.xmin + bbox.width,
] = True
return binary_mask
def _process_sample(
gt_bboxes: BoundingBoxes,
saliency_map: numpy.typing.NDArray[numpy.double],
) -> tuple[float, float]:
"""Calculates the metrics for a single sample.
Parameters
----------
gt_bboxes
A list of ground-truth bounding boxes.
saliency_map
A real-valued saliency-map that conveys regions used for
classification in the original sample.
Returns
-------
A tuple containing the following values:
* IoU
* IoDA
* Proportional energy
* Average saliency focus
* Largest detected bounding box
"""
# largest_bbox = _get_largest_bounding_boxes(saliency_map, n=1, threshold=0.2)
# detected_box = (
# largest_bbox[0] if largest_bbox else BoundingBox(-1, 0, 0, 0, 0)
# )
#
# # Calculate localization metrics
# iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_bboxes)
binary_mask = _compute_binary_mask(gt_bboxes, saliency_map)
return (
# iou,
# ioda,
_compute_proportional_energy(saliency_map, binary_mask),
_compute_avg_saliency_focus(saliency_map, binary_mask),
# (
# detected_box.xmin,
# detected_box.ymin,
# detected_box.width,
# detected_box.height,
# ),
)
def run(
input_folder: pathlib.Path,
target_label: int,
datamodule: lightning.pytorch.LightningDataModule,
) -> dict[str, list[typing.Any]]:
"""Applies visualization techniques on input CXR, outputs images with
overlaid heatmaps and csv files with measurements.
Parameters
----------
input_folder
Directory in which the saliency maps are stored for a specific
visualization type.
target_label
The label to target for evaluating interpretability metrics. Samples
containing any other label are ignored.
datamodule
The lightning datamodule to iterate on.
Returns
-------
A dictionary where keys are dataset names in the provided datamodule,
and values are lists containing sample information alongside metrics
calculated:
* Sample name (str)
* Sample target class (int)
* Proportional energy (float)
* Average saliency focus (float)
"""
retval: dict[str, list[typing.Any]] = {}
# TODO: This loads the images from the dataset, but they are not useful at
# this point. Possibly using the contents of ``datamodule.splits`` can
# substantially speed this up.
for dataset_name, dataset_loader in datamodule.predict_dataloader().items():
logger.info(
f"Estimating interpretability metrics for dataset `{dataset_name}`..."
)
retval[dataset_name] = []
for sample in tqdm(
dataset_loader, desc="batches", leave=False, disable=None
):
name = str(sample[1]["name"][0])
label = int(sample[1]["label"].item())
if label != target_label:
# we add the entry for dataset completeness, but do not treat
# it
retval[dataset_name].append([name, label])
continue
# TODO: This is very specific to the TBX11k system for labelling
# regions of interest. We need to abstract from this to support more
# datasets and other ways to annotate.
bboxes: BoundingBoxes = sample[1].get(
"bounding_boxes", BoundingBoxes()
)
if not bboxes:
logger.warning(
f"Sample `{name}` does not contdain bounding-box information. "
f"No localization metrics can be calculated in this case. "
f"Skipping..."
)
# we add the entry for dataset completeness
retval[dataset_name].append([name, label])
continue
# we fully process this entry
retval[dataset_name].append(
[
name,
label,
*_process_sample(
bboxes[0],
numpy.load(
input_folder
/ pathlib.Path(name).with_suffix(".npy")
),
),
]
)
return retval
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import os
import pathlib
import typing
import lightning.pytorch
import matplotlib.pyplot
import numpy
import numpy.typing
import PIL.Image
import PIL.ImageColor
import PIL.ImageDraw
import torchvision.transforms.functional
from tqdm import tqdm
from ...config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes
logger = logging.getLogger(__name__)
def _overlay_saliency_map(
image: PIL.Image.Image,
saliencies: numpy.typing.NDArray[numpy.double],
colormap: typing.Literal[ # we accept any "Sequential" colormap from mpl
"viridis",
"plasma",
"inferno",
"magma",
"cividis",
"Greys",
"Purples",
"Blues",
"Greens",
"Oranges",
"Reds",
"YlOrBr",
"YlOrRd",
"OrRd",
"PuRd",
"RdPu",
"BuPu",
"GnBu",
"PuBu",
"YlGnBu",
"PuBuGn",
"BuGn",
"YlGn",
],
image_weight: float,
) -> PIL.Image.Image:
"""Creates an overlayed represention of the saliency map on the original
image.
This is a slightly modified version of the show_cam_on_image implementation in:
https://github.com/jacobgil/pytorch-grad-cam, but uses matplotlib instead
of opencv.
Parameters
----------
image
The input image that will be overlaid with the saliency map
saliencies
The saliency map that will be overlaid on the (raw) image
colormap
The name of the (matplotlib) colormap to be used
image_weight
The final result is ``image_weight * image + (1-image_weight) *
saliency_map``.
Returns
-------
A modified version of the input ``image`` with the overlaid saliency
map.
"""
image_array = numpy.array(image, dtype=numpy.float32) / 255.0
assert image_array.shape[:2] == saliencies.shape, (
f"The shape of the saliency map ({saliencies.shape}) is different "
f"from the shape of the input image ({image_array.shape[:2]})."
)
assert (
saliencies.max() <= 1
), f"The input saliency map should be in the range [0, 1] (max={saliencies.max()})"
assert (
image_weight > 0 and image_weight < 1
), f"image_weight should be in the range [0, 1], but got {image_weight}"
heatmap = matplotlib.pyplot.cm.get_cmap(colormap)(saliencies)
# For pixels where the mask is zero, the original image pixels are being
# used without a mask.
result = numpy.where(
saliencies[..., numpy.newaxis] == 0,
image_array,
(image_weight * image_array) + ((1 - image_weight) * heatmap[:, :, :3]),
)
return PIL.Image.fromarray((result * 255).astype(numpy.uint8), "RGB")
def _overlay_bounding_box(
image: PIL.Image.Image,
bbox: BoundingBox,
color: str,
width: int,
) -> PIL.Image.Image:
"""Draws ground-truth on the input image.
Parameters
----------
image
The input image on which the bounding box will be drawn
bbox
The bounding box to draw on the input image
color
The color to use for drawing the bounding box. Any of the colours in
:any:`PIL.ImageColor.colormap` are accepted.
width
The width of the bounding box, in pixels. A larger value creates a
bounding box that is thicker, towards the outside of the boxed area.
Returns
-------
A modified version of the input ``image`` with the ground-truth drawn
on the top.
"""
draw = PIL.ImageDraw.Draw(image)
draw.rectangle(
(bbox.xmin, bbox.ymin, bbox.xmax, bbox.ymax),
outline=PIL.ImageColor.getrgb(color),
width=width,
)
return image
def _process_sample(
raw_data: numpy.typing.NDArray[numpy.double],
saliencies: numpy.typing.NDArray[numpy.double],
ground_truth: BoundingBoxes,
) -> PIL.Image.Image:
"""Generates an overlayed representation of the original sample and
saliency maps.
Parameters
----------
raw_data
The raw data representing the input sample that will be overlaid with
saliency maps and annotations
saliencies
The saliency map recovered from the model, that will be imprinted on
the raw_data
ground_truth
Ground-truth annotations that may be imprinted on the final image
Returns
-------
An image with the original raw data overlayed with the different
elements as selected by the user.
"""
# we need a colour image to eventually overlay a (coloured) saliency map on
# the top, draw rectangles and other annotations in colour. So, we force
# it right up front.
retval = torchvision.transforms.functional.to_pil_image(raw_data).convert(
"RGB"
)
retval = _overlay_saliency_map(
retval, saliencies, colormap="plasma", image_weight=0.5
)
for k in ground_truth:
retval = _overlay_bounding_box(retval, k, color="green", width=2)
return retval
def run(
datamodule: lightning.pytorch.LightningDataModule,
input_folder: pathlib.Path,
target_label: int,
output_folder: pathlib.Path,
show_groundtruth: bool,
threshold: float,
):
"""Overlays saliency maps on CXR to output final images with heatmaps.
Parameters
----------
datamodule
The lightning datamodule to iterate on.
input_folder
Directory in which the saliency maps are stored for a specific
visualization type.
target_label
The label to target for evaluating interpretability metrics. Samples
containing any other label are ignored.
output_folder
Directory in which the resulting visualisations will be saved.
show_groundtruth
If set, imprint ground truth labels over the original image and
saliency maps.
threshold
The pixel values above ``threshold``% of max value are kept in the
original saliency map. Everything else is set to zero. The value
proposed on [SCORECAM-2020]_ is 0.2. Use this value if unsure.
"""
for dataset_name, dataset_loader in datamodule.predict_dataloader().items():
logger.info(
f"Generating visualisations for samples at dataset `{dataset_name}`..."
)
for sample in tqdm(
dataset_loader, desc="batches", leave=False, disable=None
):
name = str(sample[1]["name"][0])
label = int(sample[1]["label"].item())
data = sample[0][0]
if label != target_label:
# no visualisation was generated
continue
saliencies = numpy.load(
input_folder / pathlib.Path(name).with_suffix(".npy")
)
saliencies[saliencies < (threshold * saliencies.max())] = 0
# TODO: This is very specific to the TBX11k system for labelling
# regions of interest. We need to abstract from this to support more
# datasets and other ways to annotate.
if show_groundtruth:
ground_truth = sample[1].get("bounding_boxes", BoundingBoxes())
else:
ground_truth = BoundingBoxes()
# we fully process this entry
image = _process_sample(
data,
saliencies,
ground_truth,
)
# Save image
output_file_path = output_folder / pathlib.Path(name).with_suffix(
".png"
)
os.makedirs(output_file_path.parent, exist_ok=True)
image.save(output_file_path)