diff --git a/.gitignore b/.gitignore index 7f47e52f54351e1447dc8f4aa7ec4059d279db69..1599f06877666637eebc12420e0283d9f1c9bc99 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,4 @@ _citools/ _work/ .mypy_cache/ .pytest_cache/ +results/ diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 3aea0e42aae024e1f3336d280c457a7fbc11682b..c9ae0016c90b88658cbab481bf880ef729a38ce1 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -10,3 +10,17 @@ include: variables: GIT_SUBMODULE_STRATEGY: normal GIT_SUBMODULE_DEPTH: 1 + +documentation: + before_script: + # for opencv-python dependence + - apt-get update && apt-get install -y libgl1-mesa-glx > /dev/null + - !reference [.snippets, doc-prepare] + +tests: + before_script: + # for opencv-python dependence + - if [ "$TAG" == "docker" ]; then apt-get update && apt-get install -y libgl1-mesa-glx > /dev/null; fi + # for pytorch quirks on linux + - if [ "$TAG" == "docker" ]; then export OMP_NUM_THREADS=1; fi # Set the number of threads used to 1 + - !reference [.snippets, test-prepare] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6795d1df3fc3d4fde8e76b98c01ed2b3efc65ae7..eba7cd13524cc793b6d6d80af114e86b73c73079 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,7 +33,7 @@ repos: --ignore-missing-imports, ] - repo: https://github.com/asottile/pyupgrade - rev: v3.11.0 + rev: v3.14.0 hooks: - id: pyupgrade args: [--py39-plus] diff --git a/conda/meta.yaml b/conda/meta.yaml index fd2999083224f0e51939c677240c9ac626614b9f..adcb300518eff7c3ac52da07d9de606dd6f7f0e1 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -24,18 +24,21 @@ requirements: - python >=3.10 - pip - click {{ click }} + - grad-cam {{ grad_cam }} - matplotlib {{ matplotlib }} - numpy {{ numpy }} - pillow {{ pillow }} - psutil {{ psutil }} - pytorch {{ pytorch }} + - scikit-image {{ scikit_image }} - scikit-learn {{ scikit_learn }} - scipy {{ scipy }} - tabulate {{ tabulate }} - torchvision {{ torchvision }} - tqdm {{ tqdm }} - tensorboard {{ tensorboard }} - - lightning >=2.0.3 + - lightning {{ lightning }} + - lightning >=2.1.0,!=2.1.3 - clapper run: - python >=3.10 @@ -45,13 +48,15 @@ requirements: - {{ pin_compatible('pillow') }} - {{ pin_compatible('psutil') }} - {{ pin_compatible('pytorch') }} + - {{ pin_compatible('scikit-image') }} - {{ pin_compatible('scikit-learn') }} - {{ pin_compatible('scipy') }} - {{ pin_compatible('tabulate') }} - {{ pin_compatible('torchvision') }} - {{ pin_compatible('tqdm') }} - {{ pin_compatible('tensorboard') }} - - {{ pin_compatible('lightning') }} + - {{ pin_compatible('lightning', max_pin='x.x') }} + - lightning >=2.1.0,!=2.1.3 - clapper test: diff --git a/doc/api.rst b/doc/api.rst index e166ddbe5fbedce84c4f2c3876dfbdfaf0239469..54457da74cf1a82fa0c28f554b6422482b61129f 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -65,6 +65,20 @@ Functions to actuate on the data. ptbench.engine.evaluator +.. _ptbench.api.saliency: + +Saliency Map Generation and Analysis +------------------------------------ + +Engines to generate and analyze saliency mapping techniques. + +.. autosummary:: + :toctree: api/saliency + + ptbench.engine.saliency.generator + ptbench.engine.saliency.completeness + ptbench.engine.saliency.interpretability + .. _ptbench.api.utils: diff --git a/doc/install.rst b/doc/install.rst index 7826c8047002f036082ecc1a1e03858729ec3b6c..e07e4e4587a7d36a2b9685f8d30eb2d3e89a6006 100644 --- a/doc/install.rst +++ b/doc/install.rst @@ -45,6 +45,18 @@ We support two installation modes, through pip_, or mamba_ (conda). 
mamba install -c https://www.idiap.ch/software/biosignal/conda/label/beta -c conda-forge ptbench + .. tip:: + + To force-install Nvidia GPU support on Linux machines, execute: + + .. code:: sh + + $ mamba install pytorch-gpu + + # or, to force the Nvidia CUDA version (environments w/o Nvidia setup): + + $ CONDA_OVERRIDE_CUDA=11.2 mamba install 'pytorch-gpu=*=cuda112*' + .. _ptbench.setup: diff --git a/doc/references.rst b/doc/references.rst index d8758df81cf5a8bfdf714be7f060abbf6cd70125..a677685ca2b61a295e2d726d93de5d44251ce216 100644 --- a/doc/references.rst +++ b/doc/references.rst @@ -62,6 +62,22 @@ Recognition, pages 2646–2655. .. [TBX11K-SIMPLIFIED-2020] *Liu, Y., Wu, Y.-H., Ban, Y., Wang, H., and Cheng, M.-*, - **Rethinking computer-aided tuberculosis diagnosis.**, + **Rethinking computer-aided tuberculosis diagnosis**, In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 2646–2655. + +.. + [GRADCAM-2015] *B. Zhou, A. Khosla, A. Lapedriza, A. Oliva, and A. + Torralba*, **Learning Deep Features for Discriminative Localization**, In + 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). doi: + https://doi.org/10.1109/CVPR.2016.319. + +.. [SCORECAM-2020] *H. Wang et al.*, **Score-CAM: Score-Weighted Visual + Explanations for Convolutional Neural Networks** 2020 IEEE/CVF Conference on + Computer Vision and Pattern Recognition Workshops (CVPRW), Seattle, WA, USA, + 2020 pp. 111-119. doi: https://doi.org/10.1109/CVPRW50498.2020.00020 + +.. [ROAD-2022] *Y. Rong, T. Leemann, V. Borisov, G. Kasneci, and E. Kasneci*, + *A Consistent and Efficient Evaluation Strategy for Attribution Methods* in + Proceedings of the 39th International Conference on Machine Learning, PMLR, + Jun. 2022, pp. 18770–18795. https://proceedings.mlr.press/v162/rong22a.html diff --git a/pyproject.toml b/pyproject.toml index 16082e2d0c8a57b44d3d90a7aa417be960b78227..c80e518ff767e64a8b78bb3c933c23ffc1a78580 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "click", "numpy", "scipy", + "scikit-image", "scikit-learn", "tqdm", "psutil", @@ -39,9 +40,9 @@ dependencies = [ "pillow", "torch>=1.8", "torchvision>=0.10", - "lightning>=2.0.3", - "pydantic <2.0,>=1.7.4", # temporary, until issue #31 is fixed + "lightning <2.2.0a0,>=2.1.0", "tensorboard", + "grad-cam>=1.4.8", ] [project.urls] diff --git a/src/ptbench/config/data/nih_cxr14/datamodule.py b/src/ptbench/config/data/nih_cxr14/datamodule.py index 9875985f106fa4db9b90e82c0e6021c516f3e388..69f044eb701c58392e5b671db1853c03393ff08c 100644 --- a/src/ptbench/config/data/nih_cxr14/datamodule.py +++ b/src/ptbench/config/data/nih_cxr14/datamodule.py @@ -89,8 +89,10 @@ class RawDataLoader(_BaseRawDataLoader): basename, ) - # N.B.: NIH CXR-14 images are encoded as color PNGs + # N.B.: some NIH CXR-14 images are encoded as color PNGs with an alpha + # channel. 
Most are grayscale PNGs. image = PIL.Image.open(os.path.join(self.datadir, file_path)) + image = image.convert("L") # required for some images tensor = to_tensor(image) # use the code below to view generated images diff --git a/src/ptbench/config/data/tbx11k/datamodule.py b/src/ptbench/config/data/tbx11k/datamodule.py index 37ac5546fa865c16866940c3eada2641101e4c95..abe2bec7d35cf451fb98cca977fa69558935613b 100644 --- a/src/ptbench/config/data/tbx11k/datamodule.py +++ b/src/ptbench/config/data/tbx11k/datamodule.py @@ -2,12 +2,16 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import collections.abc +import dataclasses import importlib.resources import os import typing import PIL.Image +import typing_extensions +from torch.utils.data._utils.collate import default_collate_fn_map from torchvision.transforms.functional import to_tensor from ptbench.data.datamodule import CachingDataModule @@ -22,22 +26,102 @@ CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) database.""" -BoundingBoxAnnotation: typing.TypeAlias = tuple[int, int, int, int, int] -"""Location of TB radiological findings (latent or active) +@dataclasses.dataclass +class BoundingBox: + """Location of radiological findings. -Objects of this type carry bounding-box information of radiological findings on -the original 512x512 pixel images of TBX11k. The radiological findings are -defined as such: + Objects of this type carry bounding-box location of radiological findings + on the original images of TBX11k. The radiological findings are defined as + such: + + * 0/1: This labels the sign as latent TB (0), or active TB (1) + * xmin: horizontal position of bounding box upper-left corner, in pixels + * ymin: vertical position of bounding box upper-left corner, in pixels + * width: width of the bounding box, in pixels + * height: height of the bounding box, in pixels + """ + + label: int + xmin: int + ymin: int + width: int + height: int + + def area(self) -> int: + """Computes the bounding box area. + + Returns + ------- + The area in square-pixels. + """ + return self.width * self.height + + @property + def xmax(self) -> int: + return self.xmin + self.width - 1 + + @property + def ymax(self) -> int: + return self.ymin + self.height - 1 + + def intersection(self, other: typing_extensions.Self) -> int: + """Computes the area intersection between bounding boxes. + + Notice that screen geometry dictates that this is slightly different + from floating-point metrics. Consider a 1D example for the evaluation of the + intersection: + + * 2 points: x1 = 1 and x2 = 3, the distance is indeed x2 - x1 = 2 + * 2 pixels of index: i1 = 1 and i2 = 3, the segment from pixel i1 to + i2 contains 3 pixels, i.e. l = i2 - i1 + 1 + + + Parameters + ---------- + other + The other bounding box to check intersections for. + + Returns + ------- + The area intersection between this and the other bounding-box in + square pixels.
+ """ + dx = min(self.xmax, other.xmax) - max(self.xmin, other.xmin) + 1 + dy = min(self.ymax, other.ymax) - max(self.ymin, other.ymin) + 1 + + if dx >= 0 and dy >= 0: + return dx * dy + + return 0 + + +class BoundingBoxes(collections.abc.Sequence[BoundingBox]): + """A collection of bounding boxes.""" + + def __init__(self, t: typing.Sequence[BoundingBox] = []): + self.t = tuple(t) + + def __getitem__(self, index): + return self.t[index] + + def __len__(self) -> int: + return len(self.t) + + +# We update the default collate function map to use our custom function as +# explained at: +# https://pytorch.org/docs/stable/data.html#torch.utils.data.default_collate +def _collate_boundingboxes_fn(batch, *, collate_fn_map=None): + """Custom collate_fn() for pytorch dataloaders that ignores BoundingBoxes + objects.""" + return batch + + +default_collate_fn_map.update({BoundingBoxes: _collate_boundingboxes_fn}) -* 0/1: This sign is for latent TB (0), or active TB (1) -* xmin: horizontal position of bounding box upper-left corner, in pixels -* ymin: vertical position of bounding box upper-left corner, in pixels -* width: width of the bounding box, in pixels -* height: height of the bounding box, in pixels -""" DatabaseSample: typing.TypeAlias = ( - tuple[str, int] | tuple[str, int, list[BoundingBoxAnnotation]] + tuple[str, int] | tuple[str, int, tuple[tuple[int, int, int, int, int]]] ) """Type of objects in our JSON representation for this database. @@ -90,7 +174,7 @@ class RawDataLoader(_BaseRawDataLoader): return tensor, dict( label=sample[1], name=sample[0], - radsign_bboxes=self.bbox_annotations(sample), + bounding_boxes=self.bounding_boxes(sample), ) def label(self, sample: DatabaseSample) -> int: @@ -111,10 +195,8 @@ class RawDataLoader(_BaseRawDataLoader): """ return sample[1] - def bbox_annotations( - self, sample: DatabaseSample - ) -> list[BoundingBoxAnnotation]: - """Loads a single image sample label from the disk. + def bounding_boxes(self, sample: DatabaseSample) -> BoundingBoxes: + """Loads image annotated bounding-boxes from the disk. Parameters ---------- @@ -129,7 +211,10 @@ class RawDataLoader(_BaseRawDataLoader): ------- Bounding box annotations, if any available with the sample. 
""" - return sample[2] if len(sample) > 2 else [] # type: ignore + if len(sample) > 2: + return BoundingBoxes([BoundingBox(*k) for k in sample[2]]) # type: ignore + + return BoundingBoxes() def make_split(basename: str) -> DatabaseSplit: diff --git a/src/ptbench/data/augmentations.py b/src/ptbench/data/augmentations.py index a104ec4ac26b1540ed2a94a0509563fbeff0b4f8..f890c530ce4dc8a62c5df8159793c8e59b8007f4 100644 --- a/src/ptbench/data/augmentations.py +++ b/src/ptbench/data/augmentations.py @@ -257,7 +257,7 @@ class ElasticDeformation: spline_order: int = 1, mode: str = "nearest", p: float = 1.0, - parallel: int = -2, + parallel: int = -1, ): self.alpha: float = alpha self.sigma: float = sigma diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index e1ba75a3da01179e48b274d4f6d488c67a4f6bf7..82353be3fb2080b9e761dba9d2b52899316b9dd9 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -580,8 +580,10 @@ class ConcatDataModule(lightning.LightningDataModule): if value < 0: num_workers = 0 + else: num_workers = value or multiprocessing.cpu_count() + self._dataloader_multiproc["num_workers"] = num_workers if num_workers > 0 and sys.platform == "darwin": @@ -589,6 +591,10 @@ class ConcatDataModule(lightning.LightningDataModule): "multiprocessing_context" ] = multiprocessing.get_context("spawn") + # keep workers hanging around if we have multiple + if value >= 0: + self._dataloader_multiproc["persistent_workers"] = True + @property def model_transforms(self) -> list[Transform] | None: """Transforms required to fit data into the model. @@ -717,7 +723,7 @@ class ConcatDataModule(lightning.LightningDataModule): if self.cache_samples: logger.info( f"Loading dataset:`{name}` into memory (caching)." - f" Trade-off: CPU RAM: more | Disk: less" + f" Trade-off: CPU RAM usage: more | Disk I/O: less" ) for split, loader in self.splits[name]: datasets.append( @@ -731,7 +737,7 @@ class ConcatDataModule(lightning.LightningDataModule): else: logger.info( f"Loading dataset:`{name}` without caching." - f" Trade-off: CPU RAM: less | Disk: more" + f" Trade-off: CPU RAM usage: less | Disk I/O: more" ) for split, loader in self.splits[name]: datasets.append( diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py index 8669718ea8efc5085c090af1db8b82fd89e16e4f..6966f6feeec2d3fa11af1b12acd950b9968ffafe 100644 --- a/src/ptbench/engine/callbacks.py +++ b/src/ptbench/engine/callbacks.py @@ -18,8 +18,12 @@ logger = logging.getLogger(__name__) class LoggingCallback(lightning.pytorch.Callback): """Callback to log various training metrics and device information. - It ensures CSVLogger logs training and evaluation metrics on the same line - Note that a CSVLogger only accepts numerical values, and not strings. + Rationale: + + 1. Losses are logged at the end of every batch, accumulated and handled by + the lightning framework + 2. Everything else is done at the end of a training or validation epoch and + mostly concerns runtime metrics such as memory and cpu/gpu utilisation. 
Parameters @@ -33,13 +37,6 @@ class LoggingCallback(lightning.pytorch.Callback): def __init__(self, resource_monitor: ResourceMonitor): super().__init__() - # lists of number of samples/batch and average losses - # - we use this later to compute overall epoch losses - self._training_epoch_loss: tuple[list[int], list[float]] = ([], []) - self._validation_epoch_loss: dict[ - int, tuple[list[int], list[float]] - ] = {} - # timers self._start_training_time = 0.0 self._start_training_epoch_time = 0.0 @@ -101,7 +98,6 @@ class LoggingCallback(lightning.pytorch.Callback): The lightning module that is being trained """ self._start_training_epoch_time = time.time() - self._training_epoch_loss = ([], []) def on_train_epoch_end( self, @@ -132,17 +128,8 @@ class LoggingCallback(lightning.pytorch.Callback): # evaluates this training epoch total time, and log it epoch_time = time.time() - self._start_training_epoch_time - # Compute overall training loss considering batches and sizes - # We disconsider accumulate_grad_batches and assume they were all of - # the same size. This way, the average of averages is the overall - # average. - self._to_log["loss/train"] = torch.mean( - torch.tensor(self._training_epoch_loss[0]) - * torch.tensor(self._training_epoch_loss[1]) - ).item() - self._to_log["epoch-duration-seconds/train"] = epoch_time - self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"] + self._to_log["learning-rate"] = pl_module.optimizers().defaults["lr"] # type: ignore metrics = self._resource_monitor.data if metrics is not None: @@ -155,9 +142,23 @@ class LoggingCallback(lightning.pytorch.Callback): "missing." ) - # if no validation dataloaders, complete cycle by the end of the - # training epoch, by logging all values to the logger - self.on_cycle_end(trainer, pl_module) + overall_cycle_time = time.time() - self._start_training_epoch_time + self._to_log["cycle-time-seconds/train"] = overall_cycle_time + self._to_log["total-execution-time-seconds"] = ( + time.time() - self._start_training_time + ) + self._to_log["eta-seconds"] = overall_cycle_time * ( + trainer.max_epochs - trainer.current_epoch # type: ignore + ) + # the "step" is the tensorboard jargon for "epoch" or "batch", + # depending on how we are logging - in a more general way, it simply + # means the relative time step. + self._to_log["step"] = float(trainer.current_epoch) + + # Do not log during sanity check as results are not relevant + if not trainer.sanity_checking: + pl_module.log_dict(self._to_log) + self._to_log = {} def on_train_batch_end( self, @@ -198,8 +199,14 @@ class LoggingCallback(lightning.pytorch.Callback): batch_idx The relative number of the batch """ - self._training_epoch_loss[0].append(batch[0].shape[0]) - self._training_epoch_loss[1].append(outputs["loss"].item()) + pl_module.log( + "loss/train", + outputs["loss"].item(), + prog_bar=True, + on_step=False, + on_epoch=True, + batch_size=batch[0].shape[0], + ) def on_validation_epoch_start( self, @@ -229,7 +236,6 @@ class LoggingCallback(lightning.pytorch.Callback): The lightning module that is being trained """ self._start_validation_epoch_time = time.time() - self._validation_epoch_loss = {} def on_validation_epoch_end( self, @@ -271,20 +277,12 @@ class LoggingCallback(lightning.pytorch.Callback): "missing." ) - # Compute overall validation losses considering batches and sizes - # We disconsider accumulate_grad_batches and assume they were all - # of the same size. This way, the average of averages is the - # overall average. 
- for key in sorted(self._validation_epoch_loss.keys()): - if key == 0: - name = "loss/validation" - else: - name = f"loss/validation-{key}" - - self._to_log[name] = torch.mean( - torch.tensor(self._validation_epoch_loss[key][0]) - * torch.tensor(self._validation_epoch_loss[key][1]) - ).item() + self._to_log["step"] = float(trainer.current_epoch) + + # Do not log during sanity check as results are not relevant + if not trainer.sanity_checking: + pl_module.log_dict(self._to_log) + self._to_log = {} def on_validation_batch_end( self, @@ -330,50 +328,17 @@ class LoggingCallback(lightning.pytorch.Callback): Index of the dataloader used during validation. Use this to figure out which dataset was used for this validation epoch. """ - size, value = self._validation_epoch_loss.setdefault( - dataloader_idx, ([], []) - ) - size.append(batch[0].shape[0]) - value.append(outputs.item()) - - def on_cycle_end( - self, - trainer: lightning.pytorch.Trainer, - pl_module: lightning.pytorch.LightningModule, - ) -> None: - """Called when the training/validation cycle has ended. - - This function will log all relevant values to the various loggers. It - is supposed to be called by the end of the training cycle (consisting - of a training and validation step). - - - Parameters - ---------- - - trainer - The Lightning trainer object - - pl_module - The lightning module that is being trained - """ - # collect some final time for the whole training cycle - # Note: logging should happen at on_validation_end(), but - # apparently you can't log from there - overall_cycle_time = time.time() - self._start_training_epoch_time - self._to_log["cycle-time-seconds/train"] = overall_cycle_time - self._to_log["total-execution-time-seconds"] = ( - time.time() - self._start_training_time - ) - self._to_log["eta-seconds"] = overall_cycle_time * ( - trainer.max_epochs - trainer.current_epoch # type: ignore + if dataloader_idx == 0: + key = "loss/validation" + else: + key = f"loss/validation-{dataloader_idx}" + + pl_module.log( + key, + outputs.item(), + prog_bar=False, + on_step=False, + on_epoch=True, + batch_size=batch[0].shape[0], ) - - # Do not log during sanity check as results are not relevant - if not trainer.sanity_checking: - for k in sorted(self._to_log.keys()): - pl_module.log_dict( - {k: self._to_log[k], "step": float(trainer.current_epoch)} - ) - self._to_log = {} diff --git a/src/ptbench/engine/saliency/__init__.py b/src/ptbench/engine/saliency/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/ptbench/engine/saliency/completeness.py b/src/ptbench/engine/saliency/completeness.py new file mode 100644 index 0000000000000000000000000000000000000000..2d42f21a7fbe84cc60faddf6a4923db8f7218292 --- /dev/null +++ b/src/ptbench/engine/saliency/completeness.py @@ -0,0 +1,333 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import functools +import logging +import multiprocessing +import typing + +import lightning.pytorch +import numpy as np +import torch +import tqdm + +from pytorch_grad_cam.metrics.road import ( + ROADLeastRelevantFirstAverage, + ROADMostRelevantFirstAverage, +) +from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget + +from ...data.typing import Sample +from ...models.typing import SaliencyMapAlgorithm +from ..device import DeviceManager + +logger = logging.getLogger(__name__) + + +class 
SigmoidClassifierOutputTarget(torch.nn.Module): + def __init__(self, category): + self.category = category + + def __call__(self, model_output): + sigmoid_output = torch.sigmoid(model_output) + if len(sigmoid_output.shape) == 1: + return sigmoid_output[self.category] + return sigmoid_output[:, self.category] + + +def _calculate_road_scores( + model: lightning.pytorch.LightningModule, + images: torch.Tensor, + output_num: int, + saliency_map_callable: typing.Callable, + percentiles: typing.Sequence[int], +) -> tuple[float, float, float]: + """Calculates average ROAD scores for different removal percentiles. + + This function calculates ROAD scores by averaging the scores for + different removal (hardcoded) percentiles, for a single input image, a + given visualization method, a target class. + + + Parameters + ---------- + model + Neural network model (e.g. pasa). + images + A batch of input images to use evaluating the ROAD scores. Currently, + we only support batches with a single image. + output_num + Target output neuron to take into consideration when evaluating the + saliency maps and calculating ROAD scores + saliency_map_callable + A callable saliency-map generator from grad-cam + percentiles + A sequence of percentiles (percent x100) integer values indicating the + proportion of pixels to perturb in the original image to calculate both + MoRF and LeRF scores. + + + Returns + ------- + A 3-tuple containing floating point numbers representing the + most-relevant-first average score (``morf``), least-relevant-first + average score (``lerf``) and the combined value (``(lerf-morf)/2``). + """ + saliency_map = saliency_map_callable( + input_tensor=images, targets=[ClassifierOutputTarget(output_num)] + ) + + cam_metric_ROADMoRF_avg = ROADMostRelevantFirstAverage( + percentiles=percentiles + ) + cam_metric_ROADLeRF_avg = ROADLeastRelevantFirstAverage( + percentiles=percentiles + ) + + # Calculate ROAD scores for all percentiles and average - this is NOT the + # current processing bottleneck. If you want to optimise anyting, look at + # the evaluation of the perturbation using scipy.sparse at the + # NoisyLinearImputer, part of the grad-cam package (submodule + # ``metrics.road``. 
+ metric_target = [SigmoidClassifierOutputTarget(output_num)] + + MoRF_scores = cam_metric_ROADMoRF_avg( + input_tensor=images, + cams=saliency_map, + model=model, + targets=metric_target, + ) + + LeRF_scores = cam_metric_ROADLeRF_avg( + input_tensor=images, + cams=saliency_map, + model=model, + targets=metric_target, + ) + + return ( + float(MoRF_scores.item()), + float(LeRF_scores.item()), + float(LeRF_scores.item() - MoRF_scores.item()) / 2.0, + ) + + +def _process_sample( + sample: Sample, + model: lightning.pytorch.LightningModule, + device: torch.device, + saliency_map_callable: typing.Callable, + target_class: typing.Literal["highest", "all"], + positive_only: bool, + percentiles: typing.Sequence[int], +) -> list: + """Helper function to :py:func:`run` to be used in multiprocessing + contexts.""" + + name: str = sample[1]["name"][0] + label: int = int(sample[1]["label"].item()) + image = sample[0].to(device=device, non_blocking=torch.cuda.is_available()) + + # in binary classification systems, negative labels may be skipped + if positive_only and (model.num_classes == 1) and (label == 0): + return [name, label] + + # chooses target outputs to generate saliency maps for + if model.num_classes > 1: # type: ignore + if target_class == "all": + # test all outputs + for output_num in range(model.num_classes): # type: ignore + results = _calculate_road_scores( + model, + image, + output_num, + saliency_map_callable, + percentiles, + ) + return [name, label, output_num, *results] + + else: + # we will figure out the output with the highest value and + # evaluate the saliency mapping technique over it. + outputs = saliency_map_callable.activations_and_grads(image) # type: ignore + output_nums = np.argmax(outputs.cpu().data.numpy(), axis=-1) + assert len(output_nums) == 1 + results = _calculate_road_scores( + model, + image, + output_nums[0], + saliency_map_callable, + percentiles, + ) + return [name, label, output_nums[0], *results] + + # default route for binary classification + results = _calculate_road_scores( + model, + image, + 0, + saliency_map_callable, + percentiles, + ) + return [name, label, 0, *results] + + +def run( + model: lightning.pytorch.LightningModule, + datamodule: lightning.pytorch.LightningDataModule, + device_manager: DeviceManager, + saliency_map_algorithm: SaliencyMapAlgorithm, + target_class: typing.Literal["highest", "all"], + positive_only: bool, + percentiles: typing.Sequence[int], + parallel: int, +) -> dict[str, list[typing.Any]]: + """Evaluates ROAD scores for all samples in a datamodule. + + The ROAD algorithm was first described at [ROAD-2022]_. It estimates + explainability (in the completeness sense) of saliency maps by substituting + relevant pixels in the input image by a local average, and re-running + prediction on the altered image, and measuring changes in the output + classification score when said perturbations are in place. By substituting + most or least relevant pixels with surrounding averages, the ROAD algorithm + estimates the importance of such elements in the produced saliency map. As + 2023, this measurement technique is considered to be one of the + state-of-the-art metrics of explainability. + + This function returns a dictionary containing most-relevant-first (remove a + percentile of the most relevant pixels), least-relevant-first (remove a + percentile of the least relevant pixels), and combined ROAD evaluations per + sample for a particular saliency mapping algorithm. + + + Parameters + --------- + model + Neural network model (e.g. 
pasa). + datamodule + The lightning datamodule to iterate on. + device_manager + An internal device representation, to be used for training and + validation. This representation can be converted into a pytorch device + or a torch lightning accelerator setup. + saliency_map_algorithm + The algorithm for saliency map estimation to use. + target_class + (Use only with multi-label models) Which class to target for CAM + calculation. Can be either set to "all" or "highest". "highest" is + default, which means only saliency maps for the class with the highest + activation will be generated. + positive_only + If set, saliency maps will only be generated for positive samples (ie. + label == 1 in a binary classification task). This option is ignored on + a multi-class output model. + percentiles + A sequence of percentiles (percent x100) integer values indicating the + proportion of pixels to perturb in the original image to calculate both + MoRF and LeRF scores. + parallel + Use multiprocessing for data processing: if set to -1, disables + multiprocessing. Set to 0 to enable as many data processing instances + as processing cores as available in the system. Set to >= 1 to enable + that many multiprocessing instances for data processing. + + + Returns + ------- + + A dictionary where keys are dataset names in the provide datamodule, + and values are lists containing sample information alongside metrics + calculated: + + * Sample name + * Sample target class + * The model output number used for the ROAD analysis (0, for binary + classifers as there is typically only one output). + * ``morf``: ROAD most-relevant-first average of percentiles 20, 40, 60 and + 80 (a.k.a. AOPC-MoRF). + * ``lerf``: ROAD least-relevant-first average of percentiles 20, 40, 60 and + 80 (a.k.a. AOPC-LeRF). + * combined: Average ROAD combined score by evaluating ``(lerf-morf)/2`` + (a.k.a. AOPC-Combined). + """ + + from ...models.densenet import Densenet + from ...models.pasa import Pasa + from .generator import _create_saliency_map_callable + + if isinstance(model, Pasa): + if saliency_map_algorithm == "fullgrad": + raise ValueError( + "Fullgrad saliency map algorithm is not supported for the " + "Pasa model." + ) + target_layers = [model.fc14] # Last non-1x1 Conv2d layer + elif isinstance(model, Densenet): + target_layers = [ + model.model_ft.features.denseblock4.denselayer16.conv2, # type: ignore + ] + else: + raise TypeError(f"Model of type `{type(model)}` is not yet supported.") + + use_cuda = device_manager.device_type == "cuda" + if device_manager.device_type in ("cuda", "mps") and ( + parallel == 0 or parallel > 1 + ): + raise RuntimeError( + f"The number of multiprocessing instances is set to {parallel} and " + f"you asked to use a GPU (device = `{device_manager.device_type}`" + f"). The currently implementation can only handle a single GPU. " + f"Either disable GPU utilisation or set the number of " + f"multiprocessing instances to one, or disable multiprocessing " + "entirely (ie. set it to -1)." 
+ ) + + # prepares model for evaluation, cast to target device + device = device_manager.torch_device() + model = model.to(device) + model.eval() + + saliency_map_callable = _create_saliency_map_callable( + saliency_map_algorithm, + model, + target_layers, # type: ignore + use_cuda, + ) + + retval: dict[str, list[typing.Any]] = {} + + # our worker function + _process = functools.partial( + _process_sample, + model=model, + device=device, + saliency_map_callable=saliency_map_callable, + target_class=target_class, + positive_only=positive_only, + percentiles=percentiles, + ) + + for k, v in datamodule.predict_dataloader().items(): + retval[k] = [] + + if parallel < 0: + logger.info( + f"Computing ROAD scores for dataset `{k}` in the current " + f"process context..." + ) + for sample in tqdm.tqdm( + v, desc="samples", leave=False, disable=None + ): + retval[k].append(_process(sample)) + + else: + instances = parallel or multiprocessing.cpu_count() + logger.info( + f"Computing ROAD scores for dataset `{k}` using {instances} " + f"processes..." + ) + with multiprocessing.Pool(instances) as p: + retval[k] = list(tqdm.tqdm(p.imap(_process, v), total=len(v))) + + return retval diff --git a/src/ptbench/engine/saliency/evaluator.py b/src/ptbench/engine/saliency/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..f8a9c04eb46cfc4b17003b2ccbe1873d6e987a4b --- /dev/null +++ b/src/ptbench/engine/saliency/evaluator.py @@ -0,0 +1,307 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import typing + +import matplotlib.figure +import numpy +import numpy.typing +import tabulate + +from ...models.typing import SaliencyMapAlgorithm + + +def _reconcile_metrics( + completeness: list, + interpretability: list, +) -> list[tuple[str, int, float, float, float]]: + """Summarizes samples into a new table containing the most important scores. + + It returns a list containing a table with completeness and ROAD scores per + sample, for the selected dataset. Only samples for which both completeness and + interpretability scores are available are returned in the reconciled list. + + + Parameters + ---------- + completeness + A list of tables with the sample name and + completeness (ROAD) scores. + interpretability + A list of tables with the sample name and + interpretability (Prop. Energy) scores. + + + Returns + ------- + A list containing a table with the sample name, target label, + completeness score (Average ROAD across different ablation thresholds), + interpretability score (Proportional Energy), and the ROAD-Weighted + Proportional Energy score. The ROAD-Weighted Prop. Energy score is + defined as: + + .. math:: + + \\text{ROAD-WeightedPropEng} = \\max(0, \\text{AvgROAD}) \\cdot + \\text{ProportionalEnergy} + """ + retval: list[tuple[str, int, float, float, float]] = [] + + for compl_info, interp_info in zip(completeness, interpretability): + # ensure matching sample name and label + assert compl_info[0] == interp_info[0] + assert compl_info[1] == interp_info[1] + + if len(compl_info) == len(interp_info) == 2: + # there is no data.
+ continue + + aopc_combined = compl_info[5] + prop_energy = interp_info[2] + road_weighted_prop_energy = max(0, aopc_combined) * prop_energy + + retval.append( + ( + compl_info[0], + compl_info[1], + aopc_combined, + prop_energy, + road_weighted_prop_energy, + ) + ) + + return retval + + +def _make_histogram( + name: str, + values: numpy.typing.NDArray, + xlim: tuple[float, float] | None = None, + title: None | str = None, +) -> matplotlib.figure.Figure: + """Builds an histogram of values. + + Parameters + ---------- + name + Name of the variable to be histogrammed (will appear in the figure) + values + Values to be histogrammed + xlim + A tuple representing the X-axis maximum and minimum to plot. If not + set, then use the bin boundaries. + title + A title to set on the histogram + + + Returns + ------- + A matplotlib figure containing the histogram. + """ + + from matplotlib import pyplot + + fig, ax = pyplot.subplots(1) + ax = typing.cast(matplotlib.figure.Axes, ax) + ax.set_xlabel(name) + ax.set_ylabel("Frequency") + + if title is not None: + ax.set_title(title) + else: + ax.set_title(f"{name} Frequency Histogram") + + n, bins, _ = ax.hist(values, bins="auto", density=True, alpha=0.7) + + if xlim is not None: + ax.spines.bottom.set_bounds(*xlim) + else: + ax.spines.bottom.set_bounds(bins[0], bins[-1]) + + ax.spines.left.set_bounds(0, n.max()) + ax.spines.right.set_visible(False) + ax.spines.top.set_visible(False) + + ax.grid(linestyle="--", linewidth=1, color="gray", alpha=0.3) + + # draw median and quartiles + quartile = numpy.percentile(values, [25, 50, 75]) + ax.axvline( + quartile[0], color="green", linestyle="--", label="Q1", alpha=0.5 + ) + ax.axvline(quartile[1], color="red", label="median", alpha=0.5) + ax.axvline( + quartile[2], color="green", linestyle="--", label="Q3", alpha=0.5 + ) + + return fig # type: ignore + + +def summary_table( + summary: dict[SaliencyMapAlgorithm, dict[str, typing.Any]], fmt: str +) -> str: + """Tabulates various summaries into one table. + + Parameters + ---------- + summary + A dictionary mapping saliency algorithm names to the results of + :py:func:`run`. + fmt + One of the formats supported by `python-tabulate + <https://pypi.org/project/tabulate/>`_. + + + Returns + ------- + A string containing the tabulated information. + """ + + headers = [ + "Algorithm", + "AOPC-Combined", + "Prop. Energy", + "ROAD-Normalised", + ] + table = [ + [ + k, + v["aopc-combined"]["quartiles"][50], + v["proportional-energy"]["quartiles"][50], + v["road-normalised-proportional-energy-average"], + ] + for k, v in summary.items() + ] + return tabulate.tabulate(table, headers, tablefmt=fmt, floatfmt=".3f") + + +def _extract_statistics( + algo: SaliencyMapAlgorithm, + data: list[tuple[str, int, float, float, float]], + name: str, + index: int, + dataset: str, + xlim: tuple[float, float] | None = None, +) -> dict[str, typing.Any]: + """Extracts all meaningful statistics from a reconciled statistics set. + + Parameters + ---------- + algo + The algorithm for saliency map estimation that is being analysed. + data + A list of tuples each containing a sample name, target, and values + produced by completeness and interpretability analysis as returned by + :py:func:`_reconcile_metrics`. + name + The name of the variable being analysed + index + Which of the indexes on the tuples containing in ``data`` that should + be extracted. 
+ dataset + The name of the dataset being analysed + xlim + Limits for histogram plotting + + + Returns + ------- + A dictionary containing the following elements: + + * ``values``: A list of values corresponding to the index on the data + * ``mean``: The mean of the value list + * ``stdev``: The standard deviation of the value list + * ``quartiles``: The 25%, 50% (median), and 75% quartile of values + * ``plot``: A histogram of values + * ``decreasing_scores``: A list of sample names and scores, sorted by + decreasing score. + """ + + val = numpy.array([k[index] for k in data]) + return dict( + values=val, + mean=val.mean(), + stdev=val.std(ddof=1), # unbiased estimator + quartiles={ + 25: numpy.percentile(val, 25), # type: ignore + 50: numpy.median(val), # type: ignore + 75: numpy.percentile(val, 75), # type: ignore + }, + plot=_make_histogram( + name, + val, + xlim=xlim, + title=f"{name} Frequency Histogram ({algo} @ {dataset})", + ), + decreasing_scores=[ + (k[0], k[index]) + for k in sorted(data, key=lambda x: x[index], reverse=True) + ], + ) + + +def run( + saliency_map_algorithm: SaliencyMapAlgorithm, + completeness: dict[str, list], + interpretability: dict[str, list], +) -> dict[str, typing.Any]: + """Evaluates a single saliency map algorithm and produces summarized + results. + + Parameters + ---------- + saliency_map_algorithm + The algorithm for saliency map estimation that is being analysed. + completeness + A dictionary mapping dataset names to tables with the sample name and + completeness (among which Average ROAD) scores. + interpretability + A dictionary mapping dataset names to tables with the sample name and + interpretability (among which Prop. Energy) scores. + + + Returns + ------- + A dictionary with the most important statistical values for the main + completeness (AOPC-Combined), interpretability (Prop. Energy), and a + combination of both (ROAD-Weighted Prop. Energy) scores.
+ """ + + retval: dict = {} + + for dataset, compl_data in completeness.items(): + reconciled = _reconcile_metrics(compl_data, interpretability[dataset]) + d = {} + + d["aopc-combined"] = _extract_statistics( + algo=saliency_map_algorithm, + data=reconciled, + name="AOPC-Combined", + index=2, + dataset=dataset, + ) + d["proportional-energy"] = _extract_statistics( + algo=saliency_map_algorithm, + data=reconciled, + name="Prop.Energy", + index=3, + dataset=dataset, + xlim=(0, 1), + ) + d["road-weighted-proportional-energy"] = _extract_statistics( + algo=saliency_map_algorithm, + data=reconciled, + name="ROAD-weighted-Prop.Energy", + index=4, + dataset=dataset, + ) + + d["road-normalised-proportional-energy-average"] = sum( + retval["road-weighted-proportional-energy"]["val"] + ) / sum([max(0, k) for k in retval["aopc-combined"]["val"]]) + + retval[dataset] = d + + return retval diff --git a/src/ptbench/engine/saliency/generator.py b/src/ptbench/engine/saliency/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..c8fa1d4614541e351af67ab341cf71bc63731b5b --- /dev/null +++ b/src/ptbench/engine/saliency/generator.py @@ -0,0 +1,222 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import logging +import pathlib +import typing + +import lightning.pytorch +import numpy +import torch +import torch.nn +import tqdm + +from ...models.typing import SaliencyMapAlgorithm +from ..device import DeviceManager + +logger = logging.getLogger(__name__) + + +def _create_saliency_map_callable( + algo_type: SaliencyMapAlgorithm, + model: torch.nn.Module, + target_layers: list[torch.nn.Module] | None, + use_cuda: bool, +): + """Creates a class activation map (CAM) instance for a given model.""" + + import pytorch_grad_cam + + match algo_type: + case "gradcam": + return pytorch_grad_cam.GradCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + case "scorecam": + return pytorch_grad_cam.ScoreCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + case "fullgrad": + return pytorch_grad_cam.FullGrad( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + case "randomcam": + return pytorch_grad_cam.RandomCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + case "hirescam": + return pytorch_grad_cam.HiResCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + case "gradcamelementwise": + return pytorch_grad_cam.GradCAMElementWise( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + case "gradcam++", "gradcamplusplus": + return pytorch_grad_cam.GradCAMPlusPlus( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + case "xgradcam": + return pytorch_grad_cam.XGradCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + case "ablationcam": + assert ( + target_layers is not None + ), "AblationCAM cannot have target_layers=None" + return pytorch_grad_cam.AblationCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + case "eigencam": + return pytorch_grad_cam.EigenCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + case "eigengradcam": + return pytorch_grad_cam.EigenGradCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + case "layercam": + return pytorch_grad_cam.LayerCAM( + model=model, target_layers=target_layers, use_cuda=use_cuda + ) + case _: + raise ValueError( + f"Saliency map algorithm `{algo_type}` is not 
currently " + f"supported." + ) + + +def _save_saliency_map( + output_folder: pathlib.Path, name: str, saliency_map: torch.Tensor +) -> None: + """Helper function to save a saliency map to disk.""" + + n = pathlib.Path(name) + (output_folder / n.parent).mkdir(parents=True, exist_ok=True) + numpy.save(output_folder / n.with_suffix(".npy"), saliency_map[0]) + + +def run( + model: lightning.pytorch.LightningModule, + datamodule: lightning.pytorch.LightningDataModule, + device_manager: DeviceManager, + saliency_map_algorithm: SaliencyMapAlgorithm, + target_class: typing.Literal["highest", "all"], + positive_only: bool, + output_folder: pathlib.Path, +) -> None: + """Applies saliency mapping techniques on input CXR, outputs pickled + saliency maps directly to disk. + + Parameters + --------- + model + Neural network model (e.g. pasa). + datamodule + The lightning datamodule to iterate on. + device_manager + An internal device representation, to be used for training and + validation. This representation can be converted into a pytorch device + or a torch lightning accelerator setup. + saliency_map_algorithm + The algorithm to use for saliency map estimation. + target_class + (Use only with multi-label models) Which class to target for CAM + calculation. Can be either set to "all" or "highest". "highest" is + default, which means only saliency maps for the class with the highest + activation will be generated. + positive_only + If set, saliency maps will only be generated for positive samples (ie. + label == 1 in a binary classification task). This option is ignored on + a multi-class output model. + output_folder + Where to save all the saliency maps (this path should exist before + this function is called) + """ + from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget + + from ...models.densenet import Densenet + from ...models.pasa import Pasa + + if isinstance(model, Pasa): + if saliency_map_algorithm == "fullgrad": + raise ValueError( + "Fullgrad saliency map algorithm is not supported for the " + "Pasa model." + ) + target_layers = [model.fc14] # Last non-1x1 Conv2d layer + elif isinstance(model, Densenet): + target_layers = [ + model.model_ft.features.denseblock4.denselayer16.conv2, # type: ignore + ] + else: + raise TypeError(f"Model of type `{type(model)}` is not yet supported.") + + use_cuda = device_manager.device_type == "cuda" + + # prepares model for evaluation, cast to target device + device = device_manager.torch_device() + model = model.to(device) + model.eval() + + saliency_map_callable = _create_saliency_map_callable( + saliency_map_algorithm, + model, + target_layers, # type: ignore + use_cuda, + ) + + for k, v in datamodule.predict_dataloader().items(): + logger.info( + f"Generating saliency maps for dataset `{k}` via `{saliency_map_algorithm}`..." + ) + + for sample in tqdm.tqdm(v, desc="samples", leave=False, disable=None): + name = sample[1]["name"][0] + label = sample[1]["label"].item() + image = sample[0].to( + device=device, non_blocking=torch.cuda.is_available() + ) + + # in binary classification systems, negative labels may be skipped + if positive_only and (model.num_classes == 1) and (label == 0): + continue + + # chooses target outputs to generate saliency maps for + if model.num_classes > 1: + if target_class == "all": + # just blindly generate saliency maps for all outputs + # - make one directory for every target output and lay + # images there like in the original dataset. 
+ for output_num in range(model.num_classes): + use_folder = output_folder / str(output_num) + saliency_map = saliency_map_callable( + input_tensor=image, + targets=[ClassifierOutputTarget(output_num)], # type: ignore + ) + _save_saliency_map(use_folder, name, saliency_map) # type: ignore + + else: + # pytorch-grad-cam will figure out the output with the + # highest value and produce a saliency map for it - we + # will save it to disk. + use_folder = output_folder / "highest-output" + saliency_map = saliency_map_callable( + input_tensor=image, + # setting `targets=None` will set target to the + # maximum output index using + # ClassifierOutputTarget(max_output_index) + targets=None, # type: ignore + ) + _save_saliency_map(use_folder, name, saliency_map) # type: ignore + else: + # binary classification model with a single output - just + # lay all cams uniformily like the original dataset + saliency_map = saliency_map_callable( + input_tensor=image, + targets=[ + ClassifierOutputTarget(0), # type: ignore + ], + ) + _save_saliency_map(output_folder, name, saliency_map) # type: ignore diff --git a/src/ptbench/engine/saliency/interpretability.py b/src/ptbench/engine/saliency/interpretability.py new file mode 100644 index 0000000000000000000000000000000000000000..dba6eee30efb2d1d431a4c7ac5274f659ab8a902 --- /dev/null +++ b/src/ptbench/engine/saliency/interpretability.py @@ -0,0 +1,433 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import logging +import pathlib +import typing + +import lightning.pytorch +import numpy +import numpy.typing +import skimage.measure +import torch +import torchvision.ops + +from tqdm import tqdm + +from ...config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes + +logger = logging.getLogger(__name__) + + +def _ordered_connected_components( + saliency_map: typing.Sequence[typing.Sequence[float]] + | numpy.typing.NDArray[numpy.double], + threshold: float, +) -> list[numpy.typing.NDArray[numpy.bool_]]: + """Calculates the largest connected components available on a saliency map + and return those as individual masks. + + This implementation is based on [SCORECAM-2020]_: + + 1. Thresholding: The pixel values above ``threshold``% of max value are + kept in the original saliency map. Everything else is set to zero. The + value proposed on [SCORECAM-2020]_ is 0.2. Use this value if unsure. + 2. The thresholded saliency map is transformed into a boolean array (ones + are attributed to all elements above the threshold. + 3. We call :py:func:`skimage.metrics.label` to evaluate all connected + components and label those with distinct integers. + 4. We histogram the labels and return one binary mask for each label, + sorted by decreasing size. + + + Parameters + ---------- + saliency_map + Input saliciency map whose connected components will be calculated + from. + threshold + Relative threshold to be used to zero parts of the original saliency + map. A value of 0.2 will zero all values in the saliency map that are + bellow 20% of the maximum value observed in the said map. + + + Returns + ------- + A list of boolean masks, one for each connected component, ordered by + decreasing size. This list may be empty if the input ``saliency_map`` + is all zeroes. 
+ """ + + # thresholds like [SCORECAM-2020]_ + saliency_array = numpy.array(saliency_map) + thresholded_mask = ( + saliency_array >= (threshold * saliency_array.max()) + ).astype(numpy.uint8) + + # avoids an all zeroes mask being processed + if not numpy.any(thresholded_mask): + return [] + + labelled, n = skimage.measure.label(thresholded_mask, return_num=True) # type: ignore + retval = [labelled == k for k in range(1, n + 1)] + + return sorted(retval, key=lambda x: x.sum(), reverse=True) + + +def _extract_bounding_box( + mask: numpy.typing.NDArray[numpy.bool_], +) -> BoundingBox: + """Defines a bounding box surrounding a connected component mask. + + Parameters + ---------- + mask + The connected component mask from whom extract the bounding box. + + + Returns + ------- + A bounding box + """ + x, y, x2, y2 = torchvision.ops.masks_to_boxes(torch.tensor(mask)[None, :])[ + 0 + ] + return BoundingBox(-1, int(x), int(y), int(x2 - x + 1), int(y2 - y + 1)) + + +def _compute_max_iou_and_ioda( + detected_box: BoundingBox, + gt_bboxes: BoundingBoxes, +) -> tuple[float, float]: + """Will calculate how much of detected area lies in ground truth boxes. + + If there are multiple gt boxes, the detected area will be calculated + for each gt box separately and the gt box with the highest + intersecting part will be used for the calculation. + """ + detected_area = detected_box.area() + if detected_area == 0: + return 0.0, 0.0 + + max_intersection = 0 + max_gt_area = 0 + + for bbox in gt_bboxes: + intersection = bbox.intersection(detected_box) + if intersection > max_intersection: + max_intersection = intersection + max_gt_area = bbox.area() + + if max_gt_area == 0 and max_intersection == 0: + # This case means no intersection was found, even though there are gt boxes + iou, ioda = 0.0, 0.0 + + else: + iou = max_intersection / ( + detected_area + max_gt_area - max_intersection + ) + ioda = max_intersection / detected_area + + return iou, ioda + + +def _get_largest_bounding_boxes( + saliency_map: typing.Sequence[typing.Sequence[float]] + | numpy.typing.NDArray[numpy.double], + n: int, + threshold: float = 0.2, +) -> list[BoundingBox]: + """Returns the N largest connected components as bounding boxes in a + saliency map. + + The return of values is subject to the value of ``threshold`` applied, as + well as on the saliency map itself. The number of objects found is also + affected by those parameters. + + + saliency_map + Input saliciency map whose connected components will be calculated + from. + n + The number of connected components to search for in the saliency map. + Connected components are then translated to bounding-box notation. + threshold + Relative threshold to be used to zero parts of the original saliency + map. A value of 0.2 will zero all values in the saliency map that are + bellow 20% of the maximum value observed in the said map. + + + Returns + ------- + """ + + retval: list[BoundingBox] = [] + + masks = _ordered_connected_components(saliency_map, threshold) + if masks: + retval += [_extract_bounding_box(k) for k in masks[:n]] + + return retval + + +def _compute_simultaneous_iou_and_ioda( + detected_box: BoundingBox, + gt_bboxes: BoundingBoxes, +) -> tuple[float, float]: + """Will calculate how much of detected area lies between ground truth + boxes. + + This means that if there are multiple gt boxes, the detected area + will be compared to them simultaneously (and not to each gt box + separately). 
+ """ + + detected_area = detected_box.area() + if detected_area == 0: + return 0, 0 + + intersection = sum([k.intersection(detected_box) for k in gt_bboxes]) + total_gt_area = sum([k.area() for k in gt_bboxes]) + + iou = intersection / (detected_area + total_gt_area - intersection) + ioda = intersection / detected_area + + return float(iou), float(ioda) + + +def _compute_avg_saliency_focus( + saliency_map: numpy.typing.NDArray[numpy.double], + gt_mask: numpy.typing.NDArray[numpy.bool_], +) -> float: + """Integrates the saliency map over the ground-truth boxes and normalizes + by total bounding-box area. + + This function will integrate (sum) the value of the saliency map over the + ground-truth bounding boxes and normalize it by the total area covered by + all ground-truth bounding boxes. + + + Parameters + ---------- + gt_bboxes + Ground-truth bounding boxes in the format ``(x, y, width, + height)``. + gt_mask + Ground-truth mask containing the bounding boxes of the ground-truth + drawn as ``True`` values. + + + Returns + ------- + A single floating-point number representing the Average saliency focus. + """ + + area = gt_mask.sum() + if area == 0: + return 0.0 + + return numpy.sum(saliency_map * gt_mask) / area + + +def _compute_proportional_energy( + saliency_map: numpy.typing.NDArray[numpy.double], + gt_mask: numpy.typing.NDArray[numpy.bool_], +) -> float: + """Calculates how much activation lies within the ground truth boxes + compared to the total sum of the activations (integral). + + Parameters + ---------- + saliency_map + A real-valued saliency-map that conveys regions used for + classification in the original sample. + gt_mask + Ground-truth mask containing the bounding boxes of the ground-truth + drawn as ``True`` values. + + + Returns + ------- + A single floating-point number representing the proportional energy. + """ + + denominator = numpy.sum(saliency_map) + + if denominator == 0.0: + return 0.0 + + return float(numpy.sum(saliency_map * gt_mask) / denominator) # type: ignore + + +def _compute_binary_mask( + gt_bboxes: BoundingBoxes, + saliency_map: numpy.typing.NDArray[numpy.double], +) -> numpy.typing.NDArray[numpy.bool_]: + """Computes a binary mask for the saliency map using BoundingBoxes. + + The binary_mask will be ON/True where the gt boxes are located. + + Parameters + ---------- + gt_bboxes + Ground-truth bounding boxes in the format ``(x, y, width, + height)``. + + saliency_map + A real-valued saliency-map that conveys regions used for + classification in the original sample. + + + Returns + ------- + A numpy array of the same size as saliency_map with + the value False everywhere except at the positions inside + the bounding boxes, which will be True. + """ + + binary_mask = numpy.zeros_like(saliency_map, dtype=numpy.bool_) + for bbox in gt_bboxes: + binary_mask[ + bbox.ymin : bbox.ymin + bbox.height, + bbox.xmin : bbox.xmin + bbox.width, + ] = True + return binary_mask + + +def _process_sample( + gt_bboxes: BoundingBoxes, + saliency_map: numpy.typing.NDArray[numpy.double], +) -> tuple[float, float]: + """Calculates the metrics for a single sample. + + Parameters + ---------- + gt_bboxes + A list of ground-truth bounding boxes. 
+ + + Returns + ------- + A tuple containing the following values: + + * Proportional energy + * Average saliency focus + + (IoU, IoDA and the largest detected bounding box are currently not + computed; the corresponding code is commented out below.) + """ + + # largest_bbox = _get_largest_bounding_boxes(saliency_map, n=1, threshold=0.2) + # detected_box = ( + # largest_bbox[0] if largest_bbox else BoundingBox(-1, 0, 0, 0, 0) + # ) + # + # # Calculate localization metrics + # iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_bboxes) + + binary_mask = _compute_binary_mask(gt_bboxes, saliency_map) + + return ( + # iou, + # ioda, + _compute_proportional_energy(saliency_map, binary_mask), + _compute_avg_saliency_focus(saliency_map, binary_mask), + # ( + # detected_box.xmin, + # detected_box.ymin, + # detected_box.width, + # detected_box.height, + # ), + ) + + +def run( + input_folder: pathlib.Path, + target_label: int, + datamodule: lightning.pytorch.LightningDataModule, +) -> dict[str, list[typing.Any]]: + """Computes interpretability metrics (proportional energy and average + saliency focus) for pre-computed saliency maps of all samples in a + datamodule. + + Parameters + ---------- + input_folder + Directory in which the saliency maps are stored for a specific + visualization type. + target_label + The label to target for evaluating interpretability metrics. Samples + containing any other label are ignored. + datamodule + The lightning datamodule to iterate on. + + + Returns + ------- + + A dictionary where keys are dataset names in the provided datamodule, + and values are lists containing sample information alongside metrics + calculated: + + * Sample name (str) + * Sample target class (int) + * Proportional energy (float) + * Average saliency focus (float) + """ + + retval: dict[str, list[typing.Any]] = {} + + # TODO: This loads the images from the dataset, but they are not useful at + # this point. Possibly using the contents of ``datamodule.splits`` can + # substantially speed this up. + for dataset_name, dataset_loader in datamodule.predict_dataloader().items(): + logger.info( + f"Estimating interpretability metrics for dataset `{dataset_name}`..." + ) + retval[dataset_name] = [] + + for sample in tqdm( + dataset_loader, desc="batches", leave=False, disable=None + ): + name = str(sample[1]["name"][0]) + label = int(sample[1]["label"].item()) + + if label != target_label: + # we add the entry for dataset completeness, but do not treat + # it + retval[dataset_name].append([name, label]) + continue + + # TODO: This is very specific to the TBX11k system for labelling + # regions of interest. We need to abstract from this to support more + # datasets and other ways to annotate. + bboxes: BoundingBoxes = sample[1].get( + "bounding_boxes", BoundingBoxes() + ) + + if not bboxes: + logger.warning( + f"Sample `{name}` does not contain bounding-box information. " + f"No localization metrics can be calculated in this case. " + f"Skipping..." 
+ ) + # we add the entry for dataset completeness + retval[dataset_name].append([name, label]) + continue + + # we fully process this entry + retval[dataset_name].append( + [ + name, + label, + *_process_sample( + bboxes[0], + numpy.load( + input_folder + / pathlib.Path(name).with_suffix(".npy") + ), + ), + ] + ) + + return retval diff --git a/src/ptbench/engine/saliency/viewer.py b/src/ptbench/engine/saliency/viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..3c0a7efe1300e069c47189274fc556f177050759 --- /dev/null +++ b/src/ptbench/engine/saliency/viewer.py @@ -0,0 +1,263 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import logging +import os +import pathlib +import typing + +import lightning.pytorch +import matplotlib.pyplot +import numpy +import numpy.typing +import PIL.Image +import PIL.ImageColor +import PIL.ImageDraw +import torchvision.transforms.functional + +from tqdm import tqdm + +from ...config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes + +logger = logging.getLogger(__name__) + + +def _overlay_saliency_map( + image: PIL.Image.Image, + saliencies: numpy.typing.NDArray[numpy.double], + colormap: typing.Literal[ # we accept any "Sequential" colormap from mpl + "viridis", + "plasma", + "inferno", + "magma", + "cividis", + "Greys", + "Purples", + "Blues", + "Greens", + "Oranges", + "Reds", + "YlOrBr", + "YlOrRd", + "OrRd", + "PuRd", + "RdPu", + "BuPu", + "GnBu", + "PuBu", + "YlGnBu", + "PuBuGn", + "BuGn", + "YlGn", + ], + image_weight: float, +) -> PIL.Image.Image: + """Creates an overlayed represention of the saliency map on the original + image. + + This is a slightly modified version of the show_cam_on_image implementation in: + https://github.com/jacobgil/pytorch-grad-cam, but uses matplotlib instead + of opencv. + + + Parameters + ---------- + image + The input imge that will be overlayed with the saliency map + saliencies + The saliency map that will be overlaid on the (raw) image + colormap + The name of the (matplotlib) colormap to be used + image_weight + The final result is ``image_weight * image + (1-image_weight) * + saliency_map``. + + + Returns + ------- + A modified version of the input ``image`` with the overlaid saliency + map. + """ + + image_array = numpy.array(image, dtype=numpy.float32) / 255.0 + + assert image_array.shape[:2] == saliencies.shape, ( + f"The shape of the saliency map ({saliencies.shape}) is different " + f"from the shape of the input image ({image_array.shape[:2]})." + ) + + assert ( + saliencies.max() <= 1 + ), f"The input saliency map should be in the range [0, 1] (max={saliencies.max()})" + + assert ( + image_weight > 0 and image_weight < 1 + ), f"image_weight should be in the range [0, 1], but got {image_weight}" + + heatmap = matplotlib.pyplot.cm.get_cmap(colormap)(saliencies) + + # For pixels where the mask is zero, the original image pixels are being + # used without a mask. + result = numpy.where( + saliencies[..., numpy.newaxis] == 0, + image_array, + (image_weight * image_array) + ((1 - image_weight) * heatmap[:, :, :3]), + ) + + return PIL.Image.fromarray((result * 255).astype(numpy.uint8), "RGB") + + +def _overlay_bounding_box( + image: PIL.Image.Image, + bbox: BoundingBox, + color: str, + width: int, +) -> PIL.Image.Image: + """Draws ground-truth on the input image. 
+
+    Parameters
+    ----------
+    image
+        The input image that will be overlaid with the bounding box
+    bbox
+        The bounding box to draw on the input image
+    color
+        The colour to use for drawing the bounding box. Any of the colours in
+        :any:`PIL.ImageColor.colormap` are accepted.
+    width
+        The width of the bounding box, in pixels. A larger value creates a
+        bounding box that is thicker, towards the outside of the boxed area.
+
+
+    Returns
+    -------
+    A modified version of the input ``image`` with the ground-truth drawn
+    on the top.
+    """
+
+    draw = PIL.ImageDraw.Draw(image)
+    draw.rectangle(
+        (bbox.xmin, bbox.ymin, bbox.xmax, bbox.ymax),
+        outline=PIL.ImageColor.getrgb(color),
+        width=width,
+    )
+    return image
+
+
+def _process_sample(
+    raw_data: numpy.typing.NDArray[numpy.double],
+    saliencies: numpy.typing.NDArray[numpy.double],
+    ground_truth: BoundingBoxes,
+) -> PIL.Image.Image:
+    """Generates an overlaid representation of the original sample and
+    saliency maps.
+
+    Parameters
+    ----------
+    raw_data
+        The raw data representing the input sample that will be overlaid with
+        saliency maps and annotations
+    saliencies
+        The saliency map recovered from the model, that will be imprinted on
+        the raw_data
+    ground_truth
+        Ground-truth annotations that may be imprinted on the final image
+
+
+    Returns
+    -------
+    An image with the original raw data overlaid with the different
+    elements as selected by the user.
+    """
+
+    # we need a colour image to eventually overlay a (coloured) saliency map
+    # on the top, and to draw rectangles and other annotations in colour. So,
+    # we force it right up front.
+    retval = torchvision.transforms.functional.to_pil_image(raw_data).convert(
+        "RGB"
+    )
+
+    retval = _overlay_saliency_map(
+        retval, saliencies, colormap="plasma", image_weight=0.5
+    )
+
+    for k in ground_truth:
+        retval = _overlay_bounding_box(retval, k, color="green", width=2)
+
+    return retval
+
+
+def run(
+    datamodule: lightning.pytorch.LightningDataModule,
+    input_folder: pathlib.Path,
+    target_label: int,
+    output_folder: pathlib.Path,
+    show_groundtruth: bool,
+    threshold: float,
+):
+    """Overlays saliency maps on CXRs to produce the final images with
+    heatmaps.
+
+    Parameters
+    ----------
+    datamodule
+        The lightning datamodule to iterate on.
+    input_folder
+        Directory in which the saliency maps are stored for a specific
+        visualization type.
+    target_label
+        The label to target for evaluating interpretability metrics. Samples
+        containing any other label are ignored.
+    output_folder
+        Directory in which the resulting visualisations will be saved.
+    show_groundtruth
+        If set, imprint ground truth labels over the original image and
+        saliency maps.
+    threshold
+        The pixel values above ``threshold``% of the maximum value are kept in
+        the original saliency map. Everything else is set to zero. The value
+        proposed in [SCORECAM-2020]_ is 0.2. Use this value if unsure.
+    """
+
+    for dataset_name, dataset_loader in datamodule.predict_dataloader().items():
+        logger.info(
+            f"Generating visualisations for samples at dataset `{dataset_name}`..."
+ ) + + for sample in tqdm( + dataset_loader, desc="batches", leave=False, disable=None + ): + name = str(sample[1]["name"][0]) + label = int(sample[1]["label"].item()) + data = sample[0][0] + + if label != target_label: + # no visualisation was generated + continue + + saliencies = numpy.load( + input_folder / pathlib.Path(name).with_suffix(".npy") + ) + saliencies[saliencies < (threshold * saliencies.max())] = 0 + + # TODO: This is very specific to the TBX11k system for labelling + # regions of interest. We need to abstract from this to support more + # datasets and other ways to annotate. + if show_groundtruth: + ground_truth = sample[1].get("bounding_boxes", BoundingBoxes()) + else: + ground_truth = BoundingBoxes() + + # we fully process this entry + image = _process_sample( + data, + saliencies, + ground_truth, + ) + + # Save image + output_file_path = output_folder / pathlib.Path(name).with_suffix( + ".png" + ) + os.makedirs(output_file_path.parent, exist_ok=True) + image.save(output_file_path) diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py index fccf47b8a0f57870ce8465f3e4c9d555d7f2e996..dee1ad980867e3ff485be1d2b3358e1dac9f9c3b 100644 --- a/src/ptbench/engine/trainer.py +++ b/src/ptbench/engine/trainer.py @@ -13,6 +13,7 @@ import lightning.pytorch.callbacks import lightning.pytorch.loggers import torch.nn +from ..utils.checkpointer import CHECKPOINT_ALIASES from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants from .callbacks import LoggingCallback from .device import DeviceManager @@ -47,13 +48,13 @@ def save_model_summary( summary_path = output_folder / "model-summary.txt" logger.info(f"Saving model summary at {summary_path}...") with summary_path.open("w") as f: - summary = lightning.pytorch.utilities.model_summary.ModelSummary( + summary = lightning.pytorch.utilities.model_summary.ModelSummary( # type: ignore model, max_depth=-1 ) f.write(str(summary)) return ( summary, - lightning.pytorch.utilities.model_summary.ModelSummary( + lightning.pytorch.utilities.model_summary.ModelSummary( # type: ignore model ).total_parameters, ) @@ -99,13 +100,13 @@ def static_information_to_csv( def run( model: lightning.pytorch.LightningModule, datamodule: lightning.pytorch.LightningDataModule, - checkpoint_period: int, + validation_period: int, device_manager: DeviceManager, max_epochs: int, output_folder: pathlib.Path, monitoring_interval: int | float, batch_chunk_count: int, - checkpoint: str | None, + checkpoint: pathlib.Path | None, ): """Fits a CNN model using supervised learning and save it to disk. @@ -122,9 +123,15 @@ def run( datamodule The lightning datamodule to use for training **and** validation - checkpoint_period - Save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do - not save intermediary checkpoints. + validation_period + Number of epochs after which validation happens. By default, we run + validation after every training epoch (period=1). You can change this + to make validation more sparse, by increasing the validation period. + Notice that this affects checkpoint saving. While checkpoints are + created after every training step (the last training step always + triggers the overriding of latest checkpoint), and that this process is + independent of validation runs, evaluation of the 'best' model obtained + so far based on those will be influenced by this setting. 
device_manager An internal device representation, to be used for training and @@ -177,17 +184,22 @@ def run( logging_level=logging.ERROR, ) - checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint( - output_folder, - "model_lowest_valid_loss", - save_last=True, + # This checkpointer will operate at the end of every validation epoch + # (which happens at each checkpoint period), it will then save the lowest + # validation loss model observed. It will also save the last trained model + checkpoint_minvalloss_callback = lightning.pytorch.callbacks.ModelCheckpoint( + dirpath=output_folder, + filename=CHECKPOINT_ALIASES["best"], + save_last=True, # will (re)create the last trained model, at every iteration monitor="loss/validation", mode="min", save_on_train_epoch_end=True, - every_n_epochs=checkpoint_period, + every_n_epochs=validation_period, # frequency at which it would check the "monitor" + enable_version_counter=False, # no versioning of aliased checkpoints ) - - checkpoint_callback.CHECKPOINT_NAME_LAST = "model_final_epoch" + checkpoint_minvalloss_callback.CHECKPOINT_NAME_LAST = CHECKPOINT_ALIASES[ # type: ignore + "periodic" + ] # write static information to a CSV file static_information_to_csv( @@ -204,9 +216,13 @@ def run( max_epochs=max_epochs, accumulate_grad_batches=batch_chunk_count, logger=tensorboard_logger, - check_val_every_n_epoch=1, + check_val_every_n_epoch=validation_period, log_every_n_steps=len(datamodule.train_dataloader()), - callbacks=[LoggingCallback(resource_monitor), checkpoint_callback], + callbacks=[ + LoggingCallback(resource_monitor), + checkpoint_minvalloss_callback, + ], ) - _ = trainer.fit(model, datamodule, ckpt_path=checkpoint) + checkpoint_str = checkpoint if checkpoint is None else str(checkpoint) + _ = trainer.fit(model, datamodule, ckpt_path=checkpoint_str) diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index ab66d1aa3ab183e29017e32d8957d0320381acda..5e5b3e7eee3d810aaa4c8dfed31ad3d9fa5414b5 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -77,6 +77,7 @@ class Alexnet(pl.LightningModule): super().__init__() self.name = "alexnet" + self.num_classes = num_classes self.model_transforms = [ torchvision.transforms.Resize(512, antialias=True), @@ -107,7 +108,7 @@ class Alexnet(pl.LightningModule): # Adapt output features self.model_ft.classifier[4] = torch.nn.Linear(4096, 512) - self.model_ft.classifier[6] = torch.nn.Linear(512, num_classes) + self.model_ft.classifier[6] = torch.nn.Linear(512, self.num_classes) def forward(self, x): x = self.normalizer(x) # type: ignore diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index c7def1b5b86a4a4e9e9ab474c7aa93140e0707b7..333edb11e2112db9ef1f1bf37b2d333f5b418de6 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -75,6 +75,7 @@ class Densenet(pl.LightningModule): super().__init__() self.name = "densenet-121" + self.num_classes = num_classes # image is probably large, resize first to get memory usage down self.model_transforms = [ @@ -106,7 +107,7 @@ class Densenet(pl.LightningModule): # Adapt output features self.model_ft.classifier = torch.nn.Linear( - self.model_ft.classifier.in_features, num_classes + self.model_ft.classifier.in_features, self.num_classes ) def forward(self, x): diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 112d7ef6cebacbb514725fe2946c22cd0c452c92..a7b8ae62c0c56bc2c5172058fdd3382c987514a1 100644 --- a/src/ptbench/models/pasa.py +++ 
b/src/ptbench/models/pasa.py @@ -59,6 +59,9 @@ class Pasa(pl.LightningModule): augmentation_transforms An optional sequence of torch modules containing transforms to be applied on the input **before** it is fed into the network. + + num_classes + Number of outputs (classes) for this model. """ def __init__( @@ -68,10 +71,12 @@ class Pasa(pl.LightningModule): optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_arguments: dict[str, typing.Any] = {}, augmentation_transforms: TransformSequence = [], + num_classes: int = 1, ): super().__init__() self.name = "pasa" + self.num_classes = num_classes # image is probably large, resize first to get memory usage down self.model_transforms = [ @@ -146,7 +151,9 @@ class Pasa(pl.LightningModule): self.pool2d = torch.nn.MaxPool2d( (3, 3), (2, 2) ) # Pool after conv. block - self.dense = torch.nn.Linear(80, 1) # Fully connected layer + self.dense = torch.nn.Linear( + 80, self.num_classes + ) # Fully connected layer def forward(self, x): x = self.normalizer(x) # type: ignore diff --git a/src/ptbench/models/typing.py b/src/ptbench/models/typing.py index 3eb9017c1e7bbbf860cf7be67622cdfe5db513df..883811acd439ebfb1fc40a581c18f3b225267d00 100644 --- a/src/ptbench/models/typing.py +++ b/src/ptbench/models/typing.py @@ -25,3 +25,20 @@ MultiClassPredictionSplit: typing.TypeAlias = typing.Mapping[ str, typing.Sequence[MultiClassPrediction] ] """A series of predictions for different database splits.""" + +SaliencyMapAlgorithm: typing.TypeAlias = typing.Literal[ + "ablationcam", + "eigencam", + "eigengradcam", + "fullgrad", + "gradcam", + "gradcamelementwise", + "gradcam++", + "gradcamplusplus", + "hirescam", + "layercam", + "randomcam", + "scorecam", + "xgradcam", +] +"""Supported saliency map algorithms.""" diff --git a/src/ptbench/scripts/cli.py b/src/ptbench/scripts/cli.py index 11cba171fb6857ca3e6fe08e2c355ab12e429f0a..e8f7ba244526e807bd8d63a290c7010b530baafe 100644 --- a/src/ptbench/scripts/cli.py +++ b/src/ptbench/scripts/cli.py @@ -2,20 +2,12 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import importlib + import click from clapper.click import AliasedGroup -from . 
import ( - config, - database, - evaluate, - experiment, - predict, - train, - train_analysis, -) - @click.group( cls=AliasedGroup, @@ -26,10 +18,50 @@ def cli(): pass -cli.add_command(config.config) -cli.add_command(database.database) -cli.add_command(evaluate.evaluate) -cli.add_command(experiment.experiment) -cli.add_command(predict.predict) -cli.add_command(train.train) -cli.add_command(train_analysis.train_analysis) +cli.add_command(importlib.import_module("..config", package=__name__).config) +cli.add_command( + importlib.import_module("..database", package=__name__).database +) +cli.add_command( + importlib.import_module("..evaluate", package=__name__).evaluate +) +cli.add_command( + importlib.import_module("..experiment", package=__name__).experiment +) +cli.add_command(importlib.import_module("..predict", package=__name__).predict) +cli.add_command(importlib.import_module("..train", package=__name__).train) +cli.add_command( + importlib.import_module("..train_analysis", package=__name__).train_analysis +) + + +@click.group( + cls=AliasedGroup, + context_settings=dict(help_option_names=["-?", "-h", "--help"]), +) +def saliency(): + """Sub-commands to generate, evaluate and view saliency maps.""" + pass + + +cli.add_command(saliency) + +saliency.add_command( + importlib.import_module("..saliency.generate", package=__name__).generate +) +saliency.add_command( + importlib.import_module( + "..saliency.completeness", package=__name__ + ).completeness +) +saliency.add_command( + importlib.import_module( + "..saliency.interpretability", package=__name__ + ).interpretability +) +saliency.add_command( + importlib.import_module("..saliency.evaluate", package=__name__).evaluate +) +saliency.add_command( + importlib.import_module("..saliency.view", package=__name__).view +) diff --git a/src/ptbench/scripts/click.py b/src/ptbench/scripts/click.py index 39cf96f962fced761bb68c12a593a24be19e9fa5..07dfe697c5f0e35356f75958c949c97aaef36aa8 100644 --- a/src/ptbench/scripts/click.py +++ b/src/ptbench/scripts/click.py @@ -27,4 +27,4 @@ class ConfigCommand(_BaseConfigCommand): if self.epilog: formatter.write_paragraph() for line in self.epilog.split("\n"): - formatter.write_text(line) + formatter.write(line + "\n") diff --git a/src/ptbench/scripts/evaluate.py b/src/ptbench/scripts/evaluate.py index 2e4fe626fbba296b030c82effc53390924096dbc..72f6b3061feaa012788ad10bd2fc12b18837c74d 100644 --- a/src/ptbench/scripts/evaluate.py +++ b/src/ptbench/scripts/evaluate.py @@ -49,7 +49,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") "--output-folder", "-o", help="Path where to store the analysis result (created if does not exist)", - required=True, + required=False, default="results", type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path), cls=ResourceOption, diff --git a/src/ptbench/scripts/experiment.py b/src/ptbench/scripts/experiment.py index 55c24260f5d32271ed74705b77b07b6a311bb39a..f75569513efaada033140c1f78c58cf79a19a72a 100644 --- a/src/ptbench/scripts/experiment.py +++ b/src/ptbench/scripts/experiment.py @@ -42,7 +42,7 @@ def experiment( batch_chunk_count, drop_incomplete_batch, datamodule, - checkpoint_period, + validation_period, device, cache_samples, seed, @@ -84,7 +84,7 @@ def experiment( batch_chunk_count=batch_chunk_count, drop_incomplete_batch=drop_incomplete_batch, datamodule=datamodule, - checkpoint_period=checkpoint_period, + validation_period=validation_period, device=device, cache_samples=cache_samples, seed=seed, @@ -113,12 +113,6 @@ def 
experiment( from .predict import predict - # preferably, we use the best model on the validation set - # otherwise, we get the last saved model - model_file = train_output_folder / "model_lowest_valid_loss.ckpt" - if not model_file.exists(): - model_file = train_output_folder / "model_final_epoch.ckpt" - predictions_output = output_folder / "predictions.json" ctx.invoke( @@ -127,7 +121,7 @@ def experiment( model=model, datamodule=datamodule, device=device, - weight=model_file, + weight=train_output_folder, batch_size=batch_size, parallel=parallel, ) diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py index 91b66d725c83ed9802d7ff9f03c2e796d670e546..6b77f3d3d9d580df59348a45ca9841b19b9ba0a6 100644 --- a/src/ptbench/scripts/predict.py +++ b/src/ptbench/scripts/predict.py @@ -23,13 +23,13 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") .. code:: sh - ptbench predict -vv pasa montgomery --weight=path/to/model.ckpt --output=path/to/predictions.json + ptbench predict -vv pasa montgomery --weight=path/to/model-at-lowest-validation-loss.ckpt --output=path/to/predictions.json 2. Enables multi-processing data loading with 6 processes: .. code:: sh - ptbench predict -vv pasa montgomery --parallel=6 --weight=path/to/model.ckpt --output=path/to/predictions.json + ptbench predict -vv pasa montgomery --parallel=6 --weight=path/to/model-at-lowest-validation-loss.ckpt --output=path/to/predictions.json """, ) @@ -88,10 +88,18 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") "--weight", "-w", help="""Path or URL to pretrained model file (`.ckpt` extension), - corresponding to the architecture set with `--model`.""", + corresponding to the architecture set with `--model`. Optionally, you may + also pass a directory containing the result of a training session, in which + case either the best (lowest validation) or latest model will be loaded.""", required=True, cls=ResourceOption, - type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True), + type=click.Path( + exists=True, + file_okay=True, + dir_okay=True, + readable=True, + path_type=pathlib.Path, + ), ) @click.option( "--parallel", @@ -125,6 +133,7 @@ def predict( from ..engine.device import DeviceManager from ..engine.predictor import run + from ..utils.checkpointer import get_checkpoint_to_run_inference datamodule.set_chunk_size(batch_size, 1) datamodule.parallel = parallel @@ -133,6 +142,9 @@ def predict( datamodule.prepare_data() datamodule.setup(stage="predict") + if weight.is_dir(): + weight = get_checkpoint_to_run_inference(weight) + logger.info(f"Loading checkpoint from `{weight}`...") model = type(model).load_from_checkpoint(weight, strict=False) diff --git a/src/ptbench/scripts/saliency/__init__.py b/src/ptbench/scripts/saliency/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/ptbench/scripts/saliency/completeness.py b/src/ptbench/scripts/saliency/completeness.py new file mode 100644 index 0000000000000000000000000000000000000000..eb305b6847aed6618f88ebdaba8b032b8e6f133c --- /dev/null +++ b/src/ptbench/scripts/saliency/completeness.py @@ -0,0 +1,254 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pathlib +import typing + +import click + +from clapper.click import ResourceOption, verbosity_option +from clapper.logging import setup + +from ...models.typing 
import SaliencyMapAlgorithm +from ..click import ConfigCommand + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +@click.command( + entry_point_group="ptbench.config", + cls=ConfigCommand, + epilog="""Examples: + +1. Calculates the ROAD scores for an existing dataset configuration and stores them in .csv files: + + .. code:: sh + + ptbench saliency completeness -vv pasa tbx11k-v1-healthy-vs-atb --device="cuda" --weight=path/to/model-at-lowest-validation-loss.ckpt --output-json=path/to/completeness-scores.json + +""", +) +@click.option( + "--model", + "-m", + help="""A lightining module instance implementing the network architecture + (not the weights, necessarily) to be used for inference. Currently, only + supports pasa and densenet models.""", + required=True, + cls=ResourceOption, +) +@click.option( + "--datamodule", + "-d", + help="""A lighting data module that will be asked for prediction data + loaders. Typically, this includes all configured splits in a datamodule, + however this is not a requirement. A datamodule that returns a single + dataloader for prediction (wrapped in a dictionary) is acceptable.""", + required=True, + cls=ResourceOption, +) +@click.option( + "--output-json", + "-o", + help="""Path where to store the output JSON file containing all + measures.""", + required=True, + type=click.Path( + file_okay=True, + dir_okay=False, + path_type=pathlib.Path, + ), + default="saliency-interpretability.json", + cls=ResourceOption, +) +@click.option( + "--device", + "-x", + help='A string indicating the device to use (e.g. "cpu" or "cuda:0")', + show_default=True, + required=True, + default="cpu", + cls=ResourceOption, +) +@click.option( + "--cache-samples/--no-cache-samples", + help="If set to True, loads the sample into memory, " + "otherwise loads them at runtime.", + required=True, + show_default=True, + default=False, + cls=ResourceOption, +) +@click.option( + "--weight", + "-w", + help="""Path or URL to pretrained model file (`.ckpt` extension), + corresponding to the architecture set with `--model`. Optionally, you may + also pass a directory containing the result of a training session, in which + case either the best (lowest validation) or latest model will be loaded.""", + required=True, + cls=ResourceOption, + type=click.Path( + exists=True, + file_okay=True, + dir_okay=True, + readable=True, + path_type=pathlib.Path, + ), +) +@click.option( + "--parallel", + "-P", + help="""Use multiprocessing for data loading processing: if set to -1 + (default), disables multiprocessing. Set to 0 to enable as many data + processing instances as processing cores available in the system. Set to + >= 1 to enable that many multiprocessing instances. Note that if you + activate this option, then you must use --device=cpu, as using a GPU + concurrently is not supported.""", + type=click.IntRange(min=-1), + show_default=True, + required=True, + default=-1, + cls=ResourceOption, +) +@click.option( + "--saliency-map-algorithm", + "-s", + help="""Saliency map algorithm to be used.""", + type=click.Choice( + typing.get_args(SaliencyMapAlgorithm), case_sensitive=False + ), + default="gradcam", + show_default=True, + cls=ResourceOption, +) +@click.option( + "--target-class", + "-C", + help="""This option should only be used with multiclass models. It + defines the class to target for saliency estimation. Can be either set to + "all" or "highest". 
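The choices offered by ``-s/--saliency-map-algorithm`` above are derived mechanically from the ``SaliencyMapAlgorithm`` literal alias added to ``src/ptbench/models/typing.py``. A quick sketch of that expansion (the alias is abridged here for brevity):

.. code:: python

   import typing

   # abridged copy of the alias from src/ptbench/models/typing.py
   SaliencyMapAlgorithm = typing.Literal["gradcam", "gradcam++", "scorecam", "fullgrad"]

   # click.Choice receives the plain tuple of accepted names
   print(typing.get_args(SaliencyMapAlgorithm))
   # ('gradcam', 'gradcam++', 'scorecam', 'fullgrad')

Adding a new algorithm to the alias therefore makes it selectable in both the ``generate`` and ``completeness`` commands without further changes.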
"highest" (the default), means only saliency maps for + the class with the highest activation will be generated.""", + required=False, + type=click.Choice( + ["highest", "all"], + case_sensitive=False, + ), + default="highest", + cls=ResourceOption, +) +@click.option( + "--positive-only/--no-positive-only", + "-z/-Z", + help="""If set, and the model chosen has a single output (binary), then + saliency maps will only be generated for samples of the positive class. + This option has no effect for multiclass models.""", + default=False, + cls=ResourceOption, +) +@click.option( + "--percentile", + "-e", + help="""One or more percentiles (percent x100) integer values indicating + the proportion of pixels to perturb in the original image to calculate both + MoRF and LeRF scores.""", + multiple=True, + default=[20, 40, 60, 80], + show_default=True, + cls=ResourceOption, +) +@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) +def completeness( + model, + datamodule, + output_json, + device, + cache_samples, + weight, + parallel, + saliency_map_algorithm, + target_class, + positive_only, + percentile, + **_, +) -> None: + """Evaluates saliency map algorithm completeness using RemOve And Debias + (ROAD). + + For the selected saliency map algorithm, evaluates the completeness of + explanations using the RemOve And Debias (ROAD) algorithm. The ROAD + algorithm was first described at [ROAD-2022]_. It estimates explainability + (in the completeness sense) of saliency mapping algorithms by substituting + relevant pixels in the input image by a local average, re-running + prediction on the altered image, and measuring changes in the output + classification score when said perturbations are in place. By substituting + most or least relevant pixels with surrounding averages, the ROAD algorithm + estimates the importance of such elements in the produced saliency map. As + 2023, this measurement technique is considered to be one of the + state-of-the-art metrics of explainability. + + This program outputs a JSON file containing the ROAD evaluations (using + most-relevant-first, or MoRF, and least-relevant-first, or LeRF for each + sample in the datamodule. Values for MoRF and LeRF represent averages by + removing 20, 40, 60 and 80% of most or least relevant pixels respectively + from the image, and averaging results for all these percentiles. + + .. note:: + + This application is relatively slow when processing a large datamodule + with many (positive) samples. + """ + import json + + from ...engine.device import DeviceManager + from ...engine.saliency.completeness import run + from ...utils.checkpointer import get_checkpoint_to_run_inference + + if device in ("cuda", "mps") and (parallel == 0 or parallel > 1): + raise RuntimeError( + f"The number of multiprocessing instances is set to {parallel} and " + f"you asked to use a GPU (device = `{device}`). The currently " + f"implementation can only handle a single GPU. Either disable GPU " + f"utilisation or set the number of multiprocessing instances to " + f"one, or disable multiprocessing entirely (ie. set it to -1)." + ) + + device_manager = DeviceManager(device) + + # batch_size must be == 1 for now (underlying code is NOT prepared to + # treat multiple samples at once). 
+ datamodule.set_chunk_size(1, 1) + datamodule.cache_samples = cache_samples + datamodule.parallel = parallel + datamodule.model_transforms = model.model_transforms + + datamodule.prepare_data() + datamodule.setup(stage="predict") + + if weight.is_dir(): + weight = get_checkpoint_to_run_inference(weight) + + logger.info(f"Loading checkpoint from `{weight}`...") + model = type(model).load_from_checkpoint(weight, strict=False) + + logger.info( + f"Evaluating RemOve And Debias (ROAD) average scores for " + f"algorithm `{saliency_map_algorithm}` with percentiles " + f"`{', '.join([str(k) for k in percentile])}`..." + ) + results = run( + model=model, + datamodule=datamodule, + device_manager=device_manager, + saliency_map_algorithm=saliency_map_algorithm, + target_class=target_class, + positive_only=positive_only, + percentiles=percentile, + parallel=parallel, + ) + + output_json.parent.mkdir(parents=True, exist_ok=True) + with output_json.open("w") as f: + logger.info(f"Saving output file to `{str(output_json)}`...") + json.dump(results, f, indent=2) diff --git a/src/ptbench/scripts/saliency/evaluate.py b/src/ptbench/scripts/saliency/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..cb3acc47ac99182982d9f9983480b0295da9a573 --- /dev/null +++ b/src/ptbench/scripts/saliency/evaluate.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pathlib +import typing + +import click + +from clapper.click import ResourceOption, verbosity_option +from clapper.logging import setup + +from ...models.typing import SaliencyMapAlgorithm +from ..click import ConfigCommand + +# avoids X11/graphical desktop requirement when creating plots +__import__("matplotlib").use("agg") + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +@click.command( + entry_point_group="ptbench.config", + cls=ConfigCommand, + epilog="""Examples: + +1. Tabulates and generates plots for two saliency map algorithms: + + .. code:: sh + + ptbench saliency evaluate -vv -e gradcam path/to/gradcam-completeness.json path/to/gradcam-interpretability.json -e gradcam++ path/to/gradcam++-completeness.json path/to/gradcam++-interpretability.json +""", +) +@click.option( + "--entry", + "-e", + required=True, + multiple=True, + help=f"ENTRY is a triplet containing the algorithm name, the path to the " + f"scores issued from the completness analysis (``ptbench " + f"saliency-completness``) and scores issued from the interpretability " + f"analysis (``ptbench saliency-interpretability``), both in JSON format. " + f"Paths to score files must exist before the program is called. 
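As a back-of-the-envelope illustration of what the completeness scores represent: MoRF and LeRF values are averaged over the requested percentiles, and can be combined as in the test fixture ``tests/data/test_vis_metrics.csv``. The numbers below are hypothetical, for illustration only:

.. code:: python

   import numpy

   percentiles = [20, 40, 60, 80]

   # hypothetical per-percentile classification-score changes
   morf_scores = [0.10, 0.25, 0.40, 0.55]  # most-relevant pixels removed first
   lerf_scores = [0.70, 0.60, 0.45, 0.30]  # least-relevant pixels removed first

   morf = float(numpy.mean(morf_scores))
   lerf = float(numpy.mean(lerf_scores))
   combined = (lerf - morf) / 2  # larger values indicate a better pixel ranking

   print(f"MoRF={morf:.3f} LeRF={lerf:.3f} combined={combined:.3f}")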
Valid values " + f"for saliency map algorithms are " + f"{'|'.join(typing.get_args(SaliencyMapAlgorithm))}", + type=( + click.Choice( + typing.get_args(SaliencyMapAlgorithm), case_sensitive=False + ), + click.Path( + exists=True, + file_okay=True, + dir_okay=False, + path_type=pathlib.Path, + ), + click.Path( + exists=True, + file_okay=True, + dir_okay=False, + path_type=pathlib.Path, + ), + ), + cls=ResourceOption, +) +@click.option( + "--output-folder", + "-o", + help="Path where to store the analysis result (created if does not exist)", + required=False, + default="results", + type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path), + cls=ResourceOption, +) +@verbosity_option(logger=logger, expose_value=False) +def evaluate( + entry, + output_folder, + **_, # ignored +) -> None: + """Calculates summary statistics for a saliency map algorithm.""" + import json + + from matplotlib.backends.backend_pdf import PdfPages + + from ...engine.saliency.evaluator import run, summary_table + + summary = { + algo: run(algo, json.load(complet.open()), json.load(interp.open())) + for algo, complet, interp in entry + } + table = summary_table(summary, "rst") + click.echo(summary) + + if output_folder is not None: + output_folder.mkdir(parents=True, exist_ok=True) + + table_path = output_folder / "summary.rst" + + logger.info(f"Saving summary table at `{table_path}`...") + with table_path.open("w") as f: + f.write(table) + + figure_path = output_folder / "plots.pdf" + logger.info(f"Saving figures at `{figure_path}`...") + + with PdfPages(figure_path) as pdf: + for dataset in summary.keys(): + pdf.savefig(summary[dataset]["aopc-combined"]["plot"]) + pdf.savefig(summary[dataset]["proportional-energy"]["plot"]) + pdf.savefig( + summary[dataset]["road-weighted-proportional-energy"][ + "plot" + ] + ) diff --git a/src/ptbench/scripts/saliency/generate.py b/src/ptbench/scripts/saliency/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..b0327728c53aec36cfa08770a2bded3779e8b8ce --- /dev/null +++ b/src/ptbench/scripts/saliency/generate.py @@ -0,0 +1,205 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pathlib +import typing + +import click + +from clapper.click import ResourceOption, verbosity_option +from clapper.logging import setup + +from ...models.typing import SaliencyMapAlgorithm +from ..click import ConfigCommand + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +@click.command( + entry_point_group="ptbench.config", + cls=ConfigCommand, + epilog="""Examples: + +1. Generates saliency maps for all prediction dataloaders on a datamodule, + using a pre-trained DenseNet model, and saves them as numpy-pickeled + objects on the output directory: + + .. code:: sh + + ptbench saliency generate -vv densenet tbx11k-v1-healthy-vs-atb --weight=path/to/model-at-lowest-validation-loss.ckpt --output-folder=path/to/output + +""", +) +@click.option( + "--model", + "-m", + help="""A lightining module instance implementing the network architecture + (not the weights, necessarily) to be used for inference. Currently, only + supports pasa and densenet models.""", + required=True, + cls=ResourceOption, +) +@click.option( + "--datamodule", + "-d", + help="""A lighting data module that will be asked for prediction data + loaders. Typically, this includes all configured splits in a datamodule, + however this is not a requirement. 
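For reference, the body of the ``saliency evaluate`` command above boils down to the following pattern; a minimal sketch assuming the two JSON files were produced beforehand by ``ptbench saliency completeness`` and ``ptbench saliency interpretability`` (file names are illustrative):

.. code:: python

   import json
   import pathlib

   from ptbench.engine.saliency.evaluator import run, summary_table

   completeness = json.loads(pathlib.Path("gradcam-completeness.json").read_text())
   interpretability = json.loads(pathlib.Path("gradcam-interpretability.json").read_text())

   summary = {"gradcam": run("gradcam", completeness, interpretability)}
   print(summary_table(summary, "rst"))  # the same table the CLI writes to summary.rst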
A datamodule that returns a single + dataloader for prediction (wrapped in a dictionary) is acceptable.""", + required=True, + cls=ResourceOption, +) +@click.option( + "--output-folder", + "-o", + help="Path where to store saliency maps (created if does not exist)", + required=True, + type=click.Path( + exists=False, + file_okay=False, + dir_okay=True, + writable=True, + path_type=pathlib.Path, + ), + default="saliency-maps", + cls=ResourceOption, +) +@click.option( + "--device", + "-x", + help='A string indicating the device to use (e.g. "cpu" or "cuda:0")', + show_default=True, + required=True, + default="cpu", + cls=ResourceOption, +) +@click.option( + "--cache-samples/--no-cache-samples", + help="If set to True, loads the sample into memory, " + "otherwise loads them at runtime.", + required=True, + show_default=True, + default=False, + cls=ResourceOption, +) +@click.option( + "--weight", + "-w", + help="""Path or URL to pretrained model file (`.ckpt` extension), + corresponding to the architecture set with `--model`. Optionally, you may + also pass a directory containing the result of a training session, in which + case either the best (lowest validation) or latest model will be loaded.""", + required=True, + cls=ResourceOption, + type=click.Path( + exists=True, + file_okay=True, + dir_okay=True, + readable=True, + path_type=pathlib.Path, + ), +) +@click.option( + "--parallel", + "-P", + help="""Use multiprocessing for data loading: if set to -1 (default), + disables multiprocessing data loading. Set to 0 to enable as many data + loading instances as processing cores as available in the system. Set to + >= 1 to enable that many multiprocessing instances for data loading.""", + type=click.IntRange(min=-1), + show_default=True, + required=True, + default=-1, + cls=ResourceOption, +) +@click.option( + "--saliency-map-algorithm", + "-s", + help="""Saliency map algorithm to be used.""", + type=click.Choice( + typing.get_args(SaliencyMapAlgorithm), case_sensitive=False + ), + default="gradcam", + show_default=True, + cls=ResourceOption, +) +@click.option( + "--target-class", + "-C", + help="""This option should only be used with multiclass models. It + defines the class to target for saliency estimation. Can be either set to + "all" or "highest". "highest" (the default), means only saliency maps for + the class with the highest activation will be generated.""", + required=False, + type=click.Choice( + ["highest", "all"], + case_sensitive=False, + ), + default="highest", + cls=ResourceOption, +) +@click.option( + "--positive-only/--no-positive-only", + "-z/-Z", + help="""If set, and the model chosen has a single output (binary), then + saliency maps will only be generated for samples of the positive class. + This option has no effect for multiclass models.""", + default=False, + cls=ResourceOption, +) +@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) +def generate( + model, + datamodule, + output_folder, + device, + cache_samples, + weight, + parallel, + saliency_map_algorithm, + target_class, + positive_only, + **_, +) -> None: + """Generates saliency maps for locations on input images that affected the + prediction. + + The quality of saliency information depends on the saliency map + algorithm and trained model. 
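Each saliency map is saved as a numpy-pickled ``.npy`` file named after the sample, which the ``interpretability`` and ``view`` commands later reload. A small sketch of how such a file is consumed (the folder and sample name below are illustrative):

.. code:: python

   import pathlib

   import numpy

   saliency_dir = pathlib.Path("saliency-maps")  # --output-folder passed to `ptbench saliency generate`
   sample_name = "imgs/tb/tb0004.png"            # illustrative sample name from the datamodule

   saliencies = numpy.load(saliency_dir / pathlib.Path(sample_name).with_suffix(".npy"))

   # the viewer keeps only activations above threshold% of the maximum (default 0.2)
   threshold = 0.2
   saliencies[saliencies < (threshold * saliencies.max())] = 0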
+ """ + + from ...engine.device import DeviceManager + from ...engine.saliency.generator import run + from ...utils.checkpointer import get_checkpoint_to_run_inference + + logger.info(f"Output folder: {output_folder}") + output_folder.mkdir(parents=True, exist_ok=True) + + device_manager = DeviceManager(device) + + # batch_size must be == 1 for now (underlying code is NOT prepared to + # treat multiple samples at once). + datamodule.set_chunk_size(1, 1) + datamodule.cache_samples = cache_samples + datamodule.parallel = parallel + datamodule.model_transforms = model.model_transforms + + datamodule.prepare_data() + datamodule.setup(stage="predict") + + if weight.is_dir(): + weight = get_checkpoint_to_run_inference(weight) + + logger.info(f"Loading checkpoint from `{weight}`...") + model = type(model).load_from_checkpoint(weight, strict=False) + + run( + model=model, + datamodule=datamodule, + device_manager=device_manager, + saliency_map_algorithm=saliency_map_algorithm, + target_class=target_class, + positive_only=positive_only, + output_folder=output_folder, + ) diff --git a/src/ptbench/scripts/saliency/interpretability.py b/src/ptbench/scripts/saliency/interpretability.py new file mode 100644 index 0000000000000000000000000000000000000000..6982b8700c66ea42e374ace4ff2dad4ed16ca8c9 --- /dev/null +++ b/src/ptbench/scripts/saliency/interpretability.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pathlib + +import click + +from clapper.click import ResourceOption, verbosity_option +from clapper.logging import setup + +from ..click import ConfigCommand + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +@click.command( + entry_point_group="ptbench.config", + cls=ConfigCommand, + epilog="""Examples: + +1. Evaluate the generated saliency maps for their localization performance: + + .. code:: sh + + ptbench saliency interpretability -vv tbx11k-v1-healthy-vs-atb --input-folder=parent-folder/saliencies/ --output-json=path/to/interpretability-scores.json + +""", +) +@click.option( + "--datamodule", + "-d", + help="""A lighting data module that will be asked for prediction data + loaders. Typically, this includes all configured splits in a datamodule, + however this is not a requirement. A datamodule that returns a single + dataloader for prediction (wrapped in a dictionary) is acceptable.""", + required=True, + cls=ResourceOption, +) +@click.option( + "--input-folder", + "-i", + help="""Path where to load saliency maps from. You can generate saliency + maps with ``ptbench generate-saliencymaps``.""", + required=True, + type=click.Path( + exists=True, + file_okay=False, + dir_okay=True, + path_type=pathlib.Path, + ), + default="saliency-maps", + cls=ResourceOption, +) +@click.option( + "--target-label", + "-t", + help="""The target label that will be analysed. It must match the target + label that was used to generate the saliency maps provided with option + ``--input-folder``. 
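The JSON written by this command (see the body further below) mirrors the dictionary returned by ``ptbench.engine.saliency.interpretability.run``: keys are split names, and each row holds either ``[name, label]`` (sample skipped) or ``[name, label, proportional_energy, average_saliency_focus]``. A hedged reading sketch, where the ``test`` split name and the file path are assumptions:

.. code:: python

   import json
   import pathlib

   results = json.loads(pathlib.Path("interpretability-scores.json").read_text())

   for name, label, *metrics in results["test"]:
       if metrics:  # rows without metrics correspond to skipped samples
           prop_energy, avg_focus = metrics
           print(f"{name}: propEnergy={prop_energy:.3f} ASF={avg_focus:.3f}")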
Samples with all other labels are ignored.""", + required=True, + type=click.INT, + default=1, + cls=ResourceOption, +) +@click.option( + "--output-json", + "-o", + help="""Path where to store the output JSON file containing all + measures.""", + required=True, + type=click.Path( + file_okay=True, + dir_okay=False, + path_type=pathlib.Path, + ), + default="saliency-interpretability.json", + cls=ResourceOption, +) +@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) +def interpretability( + datamodule, + input_folder, + target_label, + output_json, + **_, +) -> None: + """Evaluates saliency map agreement with annotations (human + interpretability). + + The evaluation happens by comparing saliency maps with ground-truth + provided by any other means (typically following a manual annotation + procedure). + + .. note:: + + For obvious reasons, this evaluation is limited to datasets that + contain built-in annotations which corroborate classification. + + + As a result of the evaluation, this application creates a single JSON file + that resembles the original datamodule, with added information containing + the following measures, for each sample: + + * Proportional Energy: A measure that compares (UNthresholed) saliency maps + with annotations (based on [SCORECAM-2020]_). It estimates how much + activation lies within the ground truth boxes compared to the total sum + of the activations. + * Average Saliency Focus: estimates how much of the ground truth bounding + boxes area is covered by the activations. It is similar to the + proportional energy measure in the sense it does not need explicit + thresholding. + """ + + import json + + from ...engine.saliency.interpretability import run + + datamodule.model_transforms = [] + datamodule.prepare_data() + datamodule.setup(stage="predict") + + results = run(input_folder, target_label, datamodule) + + with output_json.open("w") as f: + logger.info(f"Saving output file to `{str(output_json)}`...") + json.dump(results, f, indent=2) diff --git a/src/ptbench/scripts/saliency/view.py b/src/ptbench/scripts/saliency/view.py new file mode 100644 index 0000000000000000000000000000000000000000..af08086f7dac2be643cd0728f2e29bf5d3929656 --- /dev/null +++ b/src/ptbench/scripts/saliency/view.py @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import os +import pathlib + +import click + +from clapper.click import ConfigCommand, ResourceOption, verbosity_option +from clapper.logging import setup + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +@click.command( + entry_point_group="ptbench.config", + cls=ConfigCommand, + epilog="""Examples: + +1. Generates visualizations in form of heatmaps from existing saliency maps for a dataset configuration: + + .. 
code:: sh + + ptbench saliency view -vv pasa tbx11k-v1-healthy-vs-atb --input-folder=parent_folder/gradcam/ --output-folder=path/to/visualizations +""", +) +@click.option( + "--model", + "-m", + help="A lightining module instance implementing the network to be used for applying the necessary data transformations.", + required=True, + cls=ResourceOption, +) +@click.option( + "--datamodule", + "-d", + help="A lighting data module containing the training, validation and test sets.", + required=True, + cls=ResourceOption, +) +@click.option( + "--input-folder", + "-i", + help="Path to the folder containing the saliency maps for a specific visualization type.", + required=True, + type=click.Path( + file_okay=False, + dir_okay=True, + writable=True, + path_type=pathlib.Path, + ), + default="visualizations", + cls=ResourceOption, +) +@click.option( + "--output-folder", + "-o", + help="Path where to store the ROAD scores (created if does not exist)", + required=True, + type=click.Path( + file_okay=False, + dir_okay=True, + writable=True, + path_type=pathlib.Path, + ), + default="visualizations", + cls=ResourceOption, +) +@click.option( + "--show-groundtruth/--no-show-groundtruth", + "-G/-g", + help="""If set, visualizations for ground truth labels will be generated. + Only works for datasets with bounding boxes.""", + is_flag=True, + default=False, + cls=ResourceOption, +) +@click.option( + "--threshold", + "-t", + help="""The pixel values above ``threshold`` % of max value are kept in the + original saliency map. Everything else is set to zero. The value proposed + on [SCORECAM-2020]_ is 0.2. Use this value if unsure.""", + show_default=True, + required=True, + default=0.2, + type=click.FloatRange(min=0, max=1), + cls=ResourceOption, +) +@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) +def view( + model, + datamodule, + input_folder, + output_folder, + show_groundtruth, + threshold, + **_, +) -> None: + """Generates heatmaps for input CXRs based on existing saliency maps.""" + + from ...engine.saliency.viewer import run + from ..utils import save_sh_command + + assert ( + input_folder != output_folder + ), "Output folder must not be the same as the input folder." + + assert not str(output_folder).startswith( + str(input_folder) + ), "Output folder must not be a subdirectory of the input folder." + + logger.info(f"Output folder: {output_folder}") + os.makedirs(output_folder, exist_ok=True) + + save_sh_command(output_folder / "command.sh") + + datamodule.set_chunk_size(1, 1) + datamodule.drop_incomplete_batch = False + # datamodule.cache_samples = cache_samples + # datamodule.parallel = parallel + datamodule.model_transforms = model.model_transforms + + datamodule.prepare_data() + datamodule.setup(stage="predict") + + run( + datamodule=datamodule, + input_folder=input_folder, + target_label=1, + output_folder=output_folder, + show_groundtruth=show_groundtruth, + threshold=threshold, + ) diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index 8dcd832a6e0a895b0d1bd8b3be5b9a8482ec37d0..b2663208b3c3bbb2ce9f3c895a49076cb6daf9ba 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -16,7 +16,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") def reusable_options(f): - """Options that can be re-used by top-level scripts (i.e. ``experiment```). + """Options that can be re-used by top-level scripts (i.e. ``experiment``). 
This decorator equips the target function ``f`` with all (reusable) ``train`` script options. @@ -125,16 +125,21 @@ def reusable_options(f): cls=ResourceOption, ) @click.option( - "--checkpoint-period", + "--validation-period", "-p", - help="""Number of epochs after which a checkpoint is saved. A value of - zero will disable check-pointing. If checkpointing is enabled and - training stops, it is automatically resumed from the last saved - checkpoint if training is restarted with the same configuration.""", + help="""Number of epochs after which validation happens. By default, + we run validation after every training epoch (period=1). You can + change this to make validation more sparse, by increasing the + validation period. Notice that this affects checkpoint saving. While + checkpoints are created after every training step (the last training + step always triggers the overriding of latest checkpoint), and that + this process is independent of validation runs, evaluation of the + 'best' model obtained so far based on those will be influenced by this + setting.""", show_default=True, - required=False, - default=None, - type=click.IntRange(min=0), + required=True, + default=1, + type=click.IntRange(min=1), cls=ResourceOption, ) @click.option( @@ -170,8 +175,9 @@ def reusable_options(f): "-P", help="""Use multiprocessing for data loading: if set to -1 (default), disables multiprocessing data loading. Set to 0 to enable as many data - loading instances as processing cores as available in the system. Set to - >= 1 to enable that many multiprocessing instances for data loading.""", + loading instances as processing cores available in the system. Set to + >= 1 to enable that many multiprocessing instances for data + loading.""", type=click.IntRange(min=-1), show_default=True, required=True, @@ -182,27 +188,19 @@ def reusable_options(f): "--monitoring-interval", "-I", help="""Time between checks for the use of resources during each training - epoch. An interval of 5 seconds, for example, will lead to CPU and GPU - resources being probed every 5 seconds during each training epoch. - Values registered in the training logs correspond to averages (or maxima) - observed through possibly many probes in each epoch. Notice that setting a - very small value may cause the probing process to become extremely busy, - potentially biasing the overall perception of resource usage.""", + epoch, in seconds. An interval of 5 seconds, for example, will lead to + CPU and GPU resources being probed every 5 seconds during each training + epoch. Values registered in the training logs correspond to averages + (or maxima) observed through possibly many probes in each epoch. + Notice that setting a very small value may cause the probing process to + become extremely busy, potentially biasing the overall perception of + resource usage.""", type=click.FloatRange(min=0.1), show_default=True, required=True, default=5.0, cls=ResourceOption, ) - @click.option( - "--resume-from", - help="""Which checkpoint to resume training from. 
If set, can be one of - `best`, `last`, or a path to a model checkpoint.""", - type=click.STRING, - required=False, - default=None, - cls=ResourceOption, - ) @click.option( "--balance-classes/--no-balance-classes", "-B/-N", @@ -243,39 +241,49 @@ def train( batch_chunk_count, drop_incomplete_batch, datamodule, - checkpoint_period, + validation_period, device, cache_samples, seed, parallel, monitoring_interval, - resume_from, balance_classes, **_, ) -> None: """Trains an CNN to perform image classification. Training is performed for a configurable number of epochs, and - generates at least a final_model.pth. It may also generate a number - of intermediate checkpoints. Checkpoints are model files (.pth - files) that are stored during the training and useful to resume the - procedure in case it stops abruptly. + generates at least a final_model.ckpt. It may also generate a + number of intermediate checkpoints. Checkpoints are model files + (.ckpt files) that are stored during the training and useful to + resume the procedure in case it stops abruptly. """ + import os + import torch from lightning.pytorch import seed_everything from ..engine.device import DeviceManager from ..engine.trainer import run - from ..utils.checkpointer import get_checkpoint + from ..utils.checkpointer import get_checkpoint_to_resume_training from .utils import save_sh_command + checkpoint_file = None + if os.path.isdir(output_folder): + try: + checkpoint_file = get_checkpoint_to_resume_training(output_folder) + except FileNotFoundError: + logger.info( + f"Folder {output_folder} already exists, but I did not" + f" find any usable checkpoint file to resume training" + f" from. Starting from scratch..." + ) + save_sh_command(output_folder / "command.sh") seed_everything(seed) - checkpoint_file = get_checkpoint(output_folder, resume_from) - # reset datamodule with user configurable options datamodule.set_chunk_size(batch_size, batch_chunk_count) datamodule.drop_incomplete_batch = drop_incomplete_batch @@ -306,25 +314,31 @@ def train( arguments["epoch"] = 0 if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"): - # Sets the model normalizer with the unaugmented-train-subset. - # this call may be a NOOP, if the model was pre-trained and expects - # different weights for the normalisation layer. + # Sets the model normalizer with the unaugmented-train-subset if we are + # starting from scratch and/or the model does not contain its own + # checkpoint loading strategy (e.g. a pytorch stock checkpoint). This + # call may be a NOOP, if the model comes from outside this framework, + # and expects different weights for the normalisation layer. if hasattr(model, "set_normalizer"): model.set_normalizer(datamodule.unshuffled_train_dataloader()) else: logger.warning( - f"Model {model.name} has no 'set_normalizer' method. Skipping." + f"Model {model.name} has no `set_normalizer` method. " + "Skipping normalization setup (unsupported external model)." ) else: # Normalizer will be loaded during model.on_load_checkpoint checkpoint = torch.load(checkpoint_file) start_epoch = checkpoint["epoch"] - logger.info(f"Resuming from epoch {start_epoch}...") + logger.info( + f"Resuming from epoch {start_epoch} " + f"(checkpoint file: `{str(checkpoint_file)}`)..." 
+ ) run( model=model, datamodule=datamodule, - checkpoint_period=checkpoint_period, + validation_period=validation_period, device_manager=DeviceManager(device), max_epochs=epochs, output_folder=output_folder, diff --git a/src/ptbench/utils/checkpointer.py b/src/ptbench/utils/checkpointer.py index 8f6685b843fd3e6ff00f1d5cc6b5fa6e0dc14e53..543f7c7008e2aef6403babdc55cd4f391d36770b 100644 --- a/src/ptbench/utils/checkpointer.py +++ b/src/ptbench/utils/checkpointer.py @@ -4,76 +4,146 @@ import logging import pathlib +import re import typing logger = logging.getLogger(__name__) -def get_checkpoint( - output_folder: pathlib.Path, - resume_from: typing.Literal["last", "best"] | str | None, -) -> str | None: - """Gets a checkpoint file. +CHECKPOINT_ALIASES = { + "best": "model-at-lowest-validation-loss-{epoch}", + "periodic": "model-at-{epoch}", +} +"""Standard paths where checkpoints may be (if produced with this +framework).""" - Can return the best or last checkpoint, or a checkpoint at a specific path. - Ensures the checkpoint exists, raising an error if it is not the case. +CHECKPOINT_EXTENSION = ".ckpt" - If ``resume_from`` is ``None``, checks the output directory if a "last" - checkpoint file already exists and returns it. If no checkpoint is found, - returns ``None``. - ``resume_from`` can also be a path to an existing checkpoint file. In this - case, we check it and return if it exists. +def _get_checkpoint_from_alias( + path: pathlib.Path, + alias: typing.Literal["best", "periodic"], +) -> pathlib.Path: + """Gets an existing checkpoint file path. + + This function can search for names matching the checkpoint alias "stem" + (ie. the prefix), and then assumes a dash "-" and a number follows that + prefix before the expected file extension. The number is parsed and + considred to be an epoch number. The latest file (the file containing the + highest epoch number) is returned. + + If only one file is present matching the alias characteristics, then it is + returned. Parameters ---------- - output_folder - Folder in which checkpoints are stored. - resume_from - Which model to get. Can be one of "best", "last", or a path to a checkpoint. - If ``None``, gets the last checkpoint if it exists, otherwise returns - ``None`` (signal to start from scratch). + path + Folder in which may contain checkpoint + alias + Can be one of "best" or "periodic". Returns ------- - Path to the requested checkpoint (as a plain string) or ``None`` (start - from scratch). + Path to the requested checkpoint, or ``None``, if no checkpoint file + matching specifications is found on the provided path. Raises ------ FileNotFoundError - In case a required file cannot be found. + In case it cannot find any file on the provided path matching the given + specifications. 
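A stand-alone sketch of the "pick the highest epoch" resolution described here (the directory listing is made up):

.. code:: python

   import re

   # checkpoint files a training session might leave behind (hypothetical)
   files = [
       "model-at-lowest-validation-loss-epoch=3.ckpt",
       "model-at-lowest-validation-loss-epoch=17.ckpt",
       "model-at-epoch=19.ckpt",
   ]

   template = "model-at-lowest-validation-loss-{epoch}.ckpt"
   pattern = re.compile(template.replace("{epoch}", r"epoch=(?P<epoch>\d+)"))

   matches = [(int(m.group("epoch")), f) for f in files if (m := pattern.match(f))]
   print(max(matches)[1])  # model-at-lowest-validation-loss-epoch=17.ckpt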
""" - # standard paths where checkpoints may be (if produced with this framework) - last_path = output_folder / "model_final_epoch.ckpt" - best_path = output_folder / "model_lowest_valid_loss.ckpt" - - if resume_from in ("last", "best"): - use_file = last_path if resume_from == "last" else best_path - if use_file.is_file(): - logger.info(f"Found checkpoint at `{str(use_file)}`") - return str(use_file) - else: - raise FileNotFoundError( - f"Could not find a checkpoint file at `{str(use_file)}`" - ) - - elif resume_from is None: - # use-case: user is re-starting a crashed/cancelled job - if last_path.is_file(): - logger.info(f"Found checkpoint at `{str(last_path)}`") - return str(last_path) - else: - return None - - elif isinstance(resume_from, str): - if pathlib.Path(resume_from).is_file(): - logger.info(f"Found checkpoint at `{resume_from}`") - return resume_from - else: - raise FileNotFoundError( - f"Could not find a checkpoint file at `{resume_from}`" - ) + + template = path / (CHECKPOINT_ALIASES[alias] + CHECKPOINT_EXTENSION) + + if template.exists(): + return template + + # otherwise, we see if we are looking for a template instead, in which case + # we must pick the latest. + assert "{epoch}" in str( + template + ), f"Template `{str(template)}` does not contain the keyword `{{epoch}}`" + + pattern = re.compile( + template.name.replace("{epoch}", r"epoch=(?P<epoch>\d+)") + ) + highest = -1 + for f in template.parent.iterdir(): + match = pattern.match(f.name) + if match is not None: + value = int(match.group("epoch")) + if value > highest: + highest = value + + if highest != -1: + return template.with_name( + template.name.replace("{epoch}", f"epoch={highest}") + ) + + raise FileNotFoundError( + f"A file matching `{str(template)}` specifications was not found" + ) + + +def get_checkpoint_to_resume_training( + path: pathlib.Path, +): + """Returns the best checkpoint file path to resume training from. + + Parameters + ---------- + path + The base directory containing either the "periodic" checkpoint to start + the training session from. + + + Returns + ------- + Path to a checkpoint file that exists on disk + + + Raises + ------ + FileNotFoundError + If none of the checkpoints can be found on the provided directory. + """ + + return _get_checkpoint_from_alias(path, "periodic") + + +def get_checkpoint_to_run_inference( + path: pathlib.Path, +): + """Returns the best checkpoint file path to run inference with. + + Parameters + ---------- + path + The base directory containing either the "best", "last" or "periodic" + checkpoint to start the training session from. + + + Returns + ------- + Path to a checkpoint file that exists on disk + + + Raises + ------ + FileNotFoundError + If none of the checkpoints can be found on the provided directory. + """ + + try: + return _get_checkpoint_from_alias(path, "best") + except FileNotFoundError: + logger.error( + "Did not find lowest-validation-loss model to run inference " + "from. Trying to search for the last periodically saved model..." 
+ ) + + return _get_checkpoint_from_alias(path, "periodic") diff --git a/tests/data/test_vis_metrics.csv b/tests/data/test_vis_metrics.csv new file mode 100644 index 0000000000000000000000000000000000000000..1a144f21b6f21693d56be19ef301c240f30398e3 --- /dev/null +++ b/tests/data/test_vis_metrics.csv @@ -0,0 +1,6 @@ +Image,MoRF,LeRF,Combined Score ((LeRF-MoRF) / 2),IoU,IoDA,propEnergy,ASF +tb0004.png,1,2,3,4,5,6,7 +tb0006.png,2,3,4,5,6,7,8 +tb0009.png,1,2,3,4,5,6,7 +tb0014.png,2,3,4,5,6,7,8 +tb0015.png,1,2,3,4,5,6,7 diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/ablationcam/targeted_class/test/tb0004.png b/tests/data/test_visualization_images/indirect-model/tbx11k/ablationcam/targeted_class/test/tb0004.png new file mode 100644 index 0000000000000000000000000000000000000000..1fc4b4d4293fa401d1bc4bb3e99bcc0083653965 Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/ablationcam/targeted_class/test/tb0004.png differ diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/ablationcam/targeted_class/train/tb0005.png b/tests/data/test_visualization_images/indirect-model/tbx11k/ablationcam/targeted_class/train/tb0005.png new file mode 100644 index 0000000000000000000000000000000000000000..1cdd8a5fc1959009da23a9b5e25c4cf09f344532 Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/ablationcam/targeted_class/train/tb0005.png differ diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/fullgrad/targeted_class/test/tb0004.png b/tests/data/test_visualization_images/indirect-model/tbx11k/fullgrad/targeted_class/test/tb0004.png new file mode 100644 index 0000000000000000000000000000000000000000..92cb8bc6e0c454cc38d26f61d1c31420f14985eb Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/fullgrad/targeted_class/test/tb0004.png differ diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/fullgrad/targeted_class/train/tb0005.png b/tests/data/test_visualization_images/indirect-model/tbx11k/fullgrad/targeted_class/train/tb0005.png new file mode 100644 index 0000000000000000000000000000000000000000..3a6aab0ebed12caca31920bde6a1bb1ae1cad7d0 Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/fullgrad/targeted_class/train/tb0005.png differ diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/gradcam/targeted_class/test/tb0004.png b/tests/data/test_visualization_images/indirect-model/tbx11k/gradcam/targeted_class/test/tb0004.png new file mode 100644 index 0000000000000000000000000000000000000000..c9cc1d12680c583a636aa60c3c28f8dd71d557ad Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/gradcam/targeted_class/test/tb0004.png differ diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/gradcam/targeted_class/train/tb0005.png b/tests/data/test_visualization_images/indirect-model/tbx11k/gradcam/targeted_class/train/tb0005.png new file mode 100644 index 0000000000000000000000000000000000000000..80fe19a9fb3a3700b787a83a2446083de6783c47 Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/gradcam/targeted_class/train/tb0005.png differ diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/randomcam/targeted_class/test/tb0004.png b/tests/data/test_visualization_images/indirect-model/tbx11k/randomcam/targeted_class/test/tb0004.png new file mode 100644 index 
0000000000000000000000000000000000000000..abd1639622163c5a0d5ef8d32ff557ef157f63f0 Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/randomcam/targeted_class/test/tb0004.png differ diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/randomcam/targeted_class/train/tb0005.png b/tests/data/test_visualization_images/indirect-model/tbx11k/randomcam/targeted_class/train/tb0005.png new file mode 100644 index 0000000000000000000000000000000000000000..6fae53802b4640edaeb106deab6252920b860abd Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/randomcam/targeted_class/train/tb0005.png differ diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/scorecam/targeted_class/test/tb0004.png b/tests/data/test_visualization_images/indirect-model/tbx11k/scorecam/targeted_class/test/tb0004.png new file mode 100644 index 0000000000000000000000000000000000000000..4819323f21e673c959edff58e4e34e1b82507169 Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/scorecam/targeted_class/test/tb0004.png differ diff --git a/tests/data/test_visualization_images/indirect-model/tbx11k/scorecam/targeted_class/train/tb0005.png b/tests/data/test_visualization_images/indirect-model/tbx11k/scorecam/targeted_class/train/tb0005.png new file mode 100644 index 0000000000000000000000000000000000000000..8c10e34985fd278b5090a63033ae1fcb0d4c3903 Binary files /dev/null and b/tests/data/test_visualization_images/indirect-model/tbx11k/scorecam/targeted_class/train/tb0005.png differ diff --git a/tests/test_cli.py b/tests/test_cli.py index 5799e332172e32d27d03ff0e4f3efdc28135e421..1936aff43af8830ad4002ccf6a37a5e1b557c8d0 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -154,14 +154,36 @@ def test_evaluate_help(): _check_help(evaluate) +def test_saliency_generate_help(): + from ptbench.scripts.saliency.generate import generate + + _check_help(generate) + + +def test_saliency_view_help(): + from ptbench.scripts.saliency.view import view + + _check_help(view) + + +def test_saliency_evaluate_help(): + from ptbench.scripts.saliency.evaluate import evaluate + + _check_help(evaluate) + + @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") def test_train_pasa_montgomery(temporary_basedir): from ptbench.scripts.train import train + from ptbench.utils.checkpointer import ( + CHECKPOINT_EXTENSION, + _get_checkpoint_from_alias, + ) runner = CliRunner() with stdout_logging() as buf: - output_folder = str(temporary_basedir / "results") + output_folder = temporary_basedir / "results" result = runner.invoke( train, [ @@ -170,17 +192,17 @@ def test_train_pasa_montgomery(temporary_basedir): "-vv", "--epochs=1", "--batch-size=1", - f"--output-folder={output_folder}", + f"--output-folder={str(output_folder)}", ], ) _assert_exit_0(result) - assert os.path.exists( - os.path.join(output_folder, "model_final_epoch.ckpt") - ) - assert os.path.exists( - os.path.join(output_folder, "model_lowest_valid_loss.ckpt") - ) + # asserts checkpoints are there, or raises FileNotFoundError + last = _get_checkpoint_from_alias(output_folder, "periodic") + assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) + best = _get_checkpoint_from_alias(output_folder, "best") + assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) + assert os.path.exists(os.path.join(output_folder, "constants.csv")) assert ( len( @@ -194,8 +216,8 @@ def test_train_pasa_montgomery(temporary_basedir): keywords = { r"^Writing command-line 
for reproduction at .*$": 1, - r"^Loading dataset:`train` without caching. Trade-off: CPU RAM: less | Disk: more.$": 1, - r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM: less | Disk: more.$": 1, + r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, + r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, r"^Applying datamodule train sampler balancing...$": 1, r"^Balancing samples from dataset using metadata targets `label`$": 1, r"^Training for at most 1 epochs.$": 1, @@ -218,10 +240,14 @@ def test_train_pasa_montgomery(temporary_basedir): @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): from ptbench.scripts.train import train + from ptbench.utils.checkpointer import ( + CHECKPOINT_EXTENSION, + _get_checkpoint_from_alias, + ) runner = CliRunner() - output_folder = str(temporary_basedir / "results/pasa_checkpoint") + output_folder = temporary_basedir / "results" / "pasa_checkpoint" result0 = runner.invoke( train, [ @@ -230,15 +256,17 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): "-vv", "--epochs=1", "--batch-size=1", - f"--output-folder={output_folder}", + f"--output-folder={str(output_folder)}", ], ) _assert_exit_0(result0) - assert os.path.exists(os.path.join(output_folder, "model_final_epoch.ckpt")) - assert os.path.exists( - os.path.join(output_folder, "model_lowest_valid_loss.ckpt") - ) + # asserts checkpoints are there, or raises FileNotFoundError + last = _get_checkpoint_from_alias(output_folder, "periodic") + assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) + best = _get_checkpoint_from_alias(output_folder, "best") + assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) + assert os.path.exists(os.path.join(output_folder, "constants.csv")) assert ( len( @@ -265,12 +293,11 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): ) _assert_exit_0(result) - assert os.path.exists( - os.path.join(output_folder, "model_final_epoch.ckpt") - ) - assert os.path.exists( - os.path.join(output_folder, "model_lowest_valid_loss.ckpt") - ) + # asserts checkpoints are there, or raises FileNotFoundError + last = _get_checkpoint_from_alias(output_folder, "periodic") + assert last.name.endswith("epoch=1" + CHECKPOINT_EXTENSION) + best = _get_checkpoint_from_alias(output_folder, "best") + assert os.path.exists(os.path.join(output_folder, "constants.csv")) assert ( @@ -286,12 +313,12 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): keywords = { r"^Writing command-line for reproduction at .*$": 1, - r"^Loading dataset:`train` without caching. Trade-off: CPU RAM: less | Disk: more.$": 1, - r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM: less | Disk: more.$": 1, + r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, + r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1, r"^Applying datamodule train sampler balancing...$": 1, r"^Balancing samples from dataset using metadata targets `label`$": 1, r"^Training for at most 2 epochs.$": 1, - r"^Resuming from epoch 0...$": 1, + r"^Resuming from epoch 0 \(checkpoint file: .*$": 1, r"^Saving model summary at.*$": 1, r"^Dataset `train` is already setup. Not re-instantiating it.$": 1, r"^Dataset `validation` is already setup. 
Not re-instantiating it.$": 1, @@ -312,11 +339,19 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir): @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") def test_predict_pasa_montgomery(temporary_basedir, datadir): from ptbench.scripts.predict import predict + from ptbench.utils.checkpointer import ( + CHECKPOINT_EXTENSION, + _get_checkpoint_from_alias, + ) runner = CliRunner() with stdout_logging() as buf: - output = str(temporary_basedir / "predictions") + output = temporary_basedir / "predictions" + last = _get_checkpoint_from_alias( + temporary_basedir / "results", "periodic" + ) + assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION) result = runner.invoke( predict, [ @@ -324,7 +359,7 @@ def test_predict_pasa_montgomery(temporary_basedir, datadir): "montgomery", "-vv", "--batch-size=1", - f"--weight={str(temporary_basedir / 'results' / 'model_final_epoch.ckpt')}", + f"--weight={str(last)}", f"--output={output}", ], ) @@ -333,7 +368,7 @@ def test_predict_pasa_montgomery(temporary_basedir, datadir): assert os.path.exists(output) keywords = { - r"^Loading dataset: * without caching. Trade-off: CPU RAM: less | Disk: more$": 3, + r"^Loading dataset: * without caching. Trade-off: CPU RAM usage: less | Disk I/O: more$": 3, r"^Loading checkpoint from .*$": 1, r"^Restoring normalizer from checkpoint.$": 1, r"^Running prediction on `train` split...$": 1, @@ -397,6 +432,45 @@ def test_evaluate_pasa_montgomery(temporary_basedir): ) +# This script does not work anymore, either fix or remove the script + this test +# def test_evaluatevis(temporary_basedir): +# import pandas as pd + +# from ptbench.scripts.evaluatevis import evaluatevis + +# runner = CliRunner() + +# # Create a sample directory structure and CSV files +# input_folder = temporary_basedir / "camutils_cli" / "gradcam" +# input_folder.mkdir(parents=True, exist_ok=True) +# class1_dir = input_folder / "class1" +# class1_dir.mkdir(parents=True, exist_ok=True) +# class2_dir = input_folder / "class2" +# class2_dir.mkdir(parents=True, exist_ok=True) + +# data = { +# "MoRF": [1, 2, 3], +# "LeRF": [2, 4, 6], +# "Combined Score ((LeRF-MoRF) / 2)": [1.5, 3, 4.5], +# "IoU": [1, 2, 3], +# "IoDA": [2, 4, 6], +# "propEnergy": [1.5, 3, 4.5], +# "ASF": [1, 2, 3], +# } +# df = pd.DataFrame(data) +# df.to_csv(class1_dir / "file1.csv", index=False) +# df.to_csv(class2_dir / "file1.csv", index=False) +# df.to_csv(class1_dir / "file2.csv", index=False) +# df.to_csv(class2_dir / "file2.csv", index=False) + +# result = runner.invoke(evaluatevis, ["-vv", "-i", str(input_folder)]) + +# assert result.exit_code == 0 + +# assert (input_folder / "file1_summary.csv").exists() +# assert (input_folder / "file2_summary.csv").exists() + + # Not enough RAM available to do this test # @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") # def test_predict_densenetrs_montgomery(temporary_basedir, datadir): diff --git a/tests/test_saliencymap_evaluator.py b/tests/test_saliencymap_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..3b7845ab99e757a1912d9087659356e8819c3a37 --- /dev/null +++ b/tests/test_saliencymap_evaluator.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later +import numpy as np + +from ptbench.config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes +from ptbench.engine.saliency.interpretability import ( + _compute_avg_saliency_focus, + _compute_binary_mask, + 
_compute_max_iou_and_ioda, + _compute_proportional_energy, + _compute_simultaneous_iou_and_ioda, + _process_sample, +) + + +def test_compute_max_iou_and_ioda(): + detected_box = BoundingBox(-1, 10, 10, 100, 100) + gt_box_dict = BoundingBox(1, 50, 50, 50, 50) + gt_box_dict2 = BoundingBox(1, 20, 20, 60, 60) + gt_boxes = BoundingBoxes([gt_box_dict, gt_box_dict2]) + + iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_boxes) + + expected_iou = 0.36 + expected_ioda = 0.36 + + assert iou == expected_iou + assert ioda == expected_ioda + + +def test_compute_max_iou_and_ioda_zero_detected_area(): + detected_box = BoundingBox(-1, 10, 10, 0, 0) + gt_box_dict = BoundingBox(1, 50, 50, 50, 50) + gt_boxes = BoundingBoxes([gt_box_dict]) + + iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_boxes) + + # Should be zero as the detected box has no area + assert iou == 0 + assert ioda == 0 + + +def test_compute_max_iou_and_ioda_zero_gt_area(): + detected_box = BoundingBox(-1, 10, 10, 100, 100) + gt_box_dict = BoundingBox(1, 50, 50, 0, 0) + gt_boxes = BoundingBoxes([gt_box_dict]) + + iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_boxes) + + # Should be zero as there is no ground truth box + assert iou == 0 + assert ioda == 0 + + +def test_compute_max_iou_and_ioda_zero_intersection(): + detected_box = BoundingBox(-1, 10, 10, 100, 100) + gt_box_dict = BoundingBox(1, 0, 0, 5, 5) + gt_boxes = BoundingBoxes([gt_box_dict]) + + iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_boxes) + + assert iou == 0 + assert ioda == 0 + + +def test_compute_simultaneous_iou_and_ioda(): + detected_box = BoundingBox(-1, 10, 10, 100, 100) + gt_box_dict = BoundingBox(1, 50, 50, 50, 50) + gt_box_dict2 = BoundingBox(1, 70, 70, 30, 30) + + gt_boxes = BoundingBoxes([gt_box_dict, gt_box_dict2]) + + iou, ioda = _compute_simultaneous_iou_and_ioda(detected_box, gt_boxes) + + assert iou == 0.34 + assert ioda == 0.34 + + +def test_compute_avg_saliency_focus(): + grayscale_cams = np.ones((200, 200)) + grayscale_cams2 = np.full((512, 512), 0.5) + grayscale_cams3 = np.zeros((256, 256)) + grayscale_cams3[50:75, 50:100] = 1 + gt_box_dict = BoundingBox(1, 50, 50, 50, 50) + gt_boxes = BoundingBoxes([gt_box_dict]) + + binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams) + binary_mask2 = _compute_binary_mask(gt_boxes, grayscale_cams2) + binary_mask3 = _compute_binary_mask(gt_boxes, grayscale_cams3) + + avg_saliency_focus = _compute_avg_saliency_focus( + grayscale_cams, binary_mask + ) + avg_saliency_focus2 = _compute_avg_saliency_focus( + grayscale_cams2, binary_mask2 + ) + avg_saliency_focus3 = _compute_avg_saliency_focus( + grayscale_cams3, binary_mask3 + ) + + assert avg_saliency_focus == 1 + assert avg_saliency_focus2 == 0.5 + assert avg_saliency_focus3 == 0.5 + + +def test_compute_avg_saliency_focus_no_activations(): + grayscale_cams = np.zeros((200, 200)) + gt_box_dict = BoundingBox(1, 50, 50, 50, 50) + gt_boxes = BoundingBoxes([gt_box_dict]) + + binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams) + avg_saliency_focus = _compute_avg_saliency_focus( + grayscale_cams, binary_mask + ) + + assert avg_saliency_focus == 0 + + +def test_compute_avg_saliency_focus_zero_gt_area(): + grayscale_cams = np.ones((200, 200)) + gt_box_dict = BoundingBox(1, 50, 50, 0, 0) + gt_boxes = BoundingBoxes([gt_box_dict]) + + binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams) + avg_saliency_focus = _compute_avg_saliency_focus( + grayscale_cams, binary_mask + ) + + assert avg_saliency_focus == 0 + + +def 
test_compute_proportional_energy(): + grayscale_cams = np.ones((200, 200)) + grayscale_cams2 = np.full((512, 512), 0.5) + grayscale_cams3 = np.zeros((512, 512)) + grayscale_cams3[100:200, 100:200] = 1 + gt_box_dict = BoundingBox(1, 50, 50, 100, 100) + gt_boxes = BoundingBoxes([gt_box_dict]) + + binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams) + binary_mask2 = _compute_binary_mask(gt_boxes, grayscale_cams2) + binary_mask3 = _compute_binary_mask(gt_boxes, grayscale_cams3) + + proportional_energy = _compute_proportional_energy( + grayscale_cams, binary_mask + ) + proportional_energy2 = _compute_proportional_energy( + grayscale_cams2, binary_mask2 + ) + proportional_energy3 = _compute_proportional_energy( + grayscale_cams3, binary_mask3 + ) + + assert proportional_energy == 0.25 + assert proportional_energy2 == 0.03814697265625 + assert proportional_energy3 == 0.25 + + +def test_compute_proportional_energy_no_activations(): + grayscale_cams = np.zeros((200, 200)) + gt_box_dict = BoundingBox(1, 50, 50, 50, 50) + gt_boxes = BoundingBoxes([gt_box_dict]) + + binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams) + proportional_energy = _compute_proportional_energy( + grayscale_cams, binary_mask + ) + + assert proportional_energy == 0 + + +def test_compute_proportional_energy_no_gt_box(): + grayscale_cams = np.ones((200, 200)) + gt_box_dict = BoundingBox(1, 0, 0, 0, 0) + gt_boxes = BoundingBoxes([gt_box_dict]) + + binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams) + proportional_energy = _compute_proportional_energy( + grayscale_cams, binary_mask + ) + + assert proportional_energy == 0 + + +def test_process_sample(): + grayscale_cams = np.ones((200, 200)) + gt_box_dict = BoundingBox(1, 50, 50, 0, 0) + gt_boxes = BoundingBoxes([gt_box_dict]) + + proportional_energy, avg_saliency_focus = _process_sample( + gt_boxes, grayscale_cams + ) + + assert proportional_energy == 0 + assert avg_saliency_focus == 0 diff --git a/tests/test_tbx11k.py b/tests/test_tbx11k.py index 1b3a72225522420d8089a54108dcd978ab1a2c44..e7dae5678b1cc3ff18cc728a485f33d730de24f6 100644 --- a/tests/test_tbx11k.py +++ b/tests/test_tbx11k.py @@ -186,26 +186,26 @@ def check_loaded_batch( [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]] ) - assert "radsign_bboxes" in batch[1] + assert "bounding_boxes" in batch[1] for sample, label, bboxes in zip( - batch[0], batch[1]["label"], batch[1]["radsign_bboxes"] + batch[0], batch[1]["label"], batch[1]["bounding_boxes"] ): # there must be a sign indicated on the image, if active TB is detected if label == 1: - assert len(bboxes[0]) != 0 + assert len(bboxes) != 0 # eif label == 0: # not true, may have TBI! # assert len(bboxes) == 0 # asserts all bounding boxes are within the raw image width and height - for bbox_label, xmin, ymin, width, height in zip(*bboxes): + for bbox in bboxes: if label == 1: - assert bbox_label == 1 + assert bbox.label == 1 else: - assert bbox_label == 0 - assert (xmin + width) < sample.shape[2] - assert (ymin + height) < sample.shape[1] + assert bbox.label == 0 + assert bbox.xmax < sample.shape[2] + assert bbox.ymax < sample.shape[1] # use the code below to view generated images # from torchvision.transforms.functional import to_pil_image diff --git a/tests/test_tranforms.py b/tests/test_transforms.py similarity index 100% rename from tests/test_tranforms.py rename to tests/test_transforms.py
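
Note on the new checkpoint lookup: the sketch below is not part of the patch; it only illustrates how the alias-based resolution in src/ptbench/utils/checkpointer.py is expected to behave. The simulated file names are an assumption based on the ``...epoch=N.ckpt`` suffix that the updated tests assert.

# Illustration only (not part of the patch).  File names below are assumed to
# follow the "...epoch=N.ckpt" suffix asserted by the updated CLI tests.
import pathlib
import tempfile

from ptbench.utils.checkpointer import (
    get_checkpoint_to_resume_training,
    get_checkpoint_to_run_inference,
)

with tempfile.TemporaryDirectory() as d:
    path = pathlib.Path(d)

    # simulate two periodically saved checkpoints and one "best" checkpoint
    (path / "model-at-epoch=0.ckpt").touch()
    (path / "model-at-epoch=4.ckpt").touch()
    (path / "model-at-lowest-validation-loss-epoch=2.ckpt").touch()

    # resuming picks the periodic checkpoint with the highest epoch number
    assert get_checkpoint_to_resume_training(path).name == "model-at-epoch=4.ckpt"

    # inference prefers the lowest-validation-loss checkpoint, falling back to
    # the latest periodic one only if no "best" checkpoint exists
    assert (
        get_checkpoint_to_run_inference(path).name
        == "model-at-lowest-validation-loss-epoch=2.ckpt"
    )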
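
Note on the saliency interpretability metrics: the expected values in tests/test_saliencymap_evaluator.py follow from the definitions sketched below. These definitions are inferred from the test data and are assumptions about, not a copy of, the implementation in ptbench.engine.saliency.interpretability.

# Worked example with assumed definitions, inferred from the expected test
# values; not the module's actual implementation.
import numpy as np

# uniform saliency map of 0.5, as in `grayscale_cams2` in the tests
saliency = np.full((512, 512), 0.5)

# binary mask for BoundingBox(1, 50, 50, 100, 100): label, xmin, ymin, w, h
mask = np.zeros_like(saliency, dtype=bool)
mask[50:150, 50:150] = True

# proportional energy: fraction of the total saliency landing inside the
# ground-truth bounding boxes
prop_energy = saliency[mask].sum() / saliency.sum()
print(prop_energy)  # (100*100*0.5) / (512*512*0.5) = 0.03814697265625

# average saliency focus: mean saliency value over the bounding-box area
avg_focus = saliency[mask].sum() / mask.sum()
print(avg_focus)  # 0.5, i.e. the uniform map value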