diff --git a/conda/meta.yaml b/conda/meta.yaml
index 1de05682e38e448fb80636367e470a9f5cd0924d..72dc7e004c665fe2f733252b80a933eefad8f6d3 100644
--- a/conda/meta.yaml
+++ b/conda/meta.yaml
@@ -30,6 +30,7 @@ requirements:
     - pillow {{ pillow }}
     - psutil {{ psutil }}
     - pytorch {{ pytorch }}
+    - scikit-image {{ scikit_image }}
     - scikit-learn {{ scikit_learn }}
     - scipy {{ scipy }}
     - tabulate {{ tabulate }}
@@ -46,6 +47,7 @@ requirements:
     - {{ pin_compatible('pillow') }}
     - {{ pin_compatible('psutil') }}
     - {{ pin_compatible('pytorch') }}
+    - {{ pin_compatible('scikit-image') }}
     - {{ pin_compatible('scikit-learn') }}
     - {{ pin_compatible('scipy') }}
     - {{ pin_compatible('tabulate') }}
diff --git a/pyproject.toml b/pyproject.toml
index 033858cdea2c17fb5c497d9b011b59806a726c32..c80e518ff767e64a8b78bb3c933c23ffc1a78580 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,7 @@ dependencies = [
     "click",
     "numpy",
     "scipy",
+    "scikit-image",
    "scikit-learn",
     "tqdm",
     "psutil",
diff --git a/src/ptbench/engine/saliency/interpretability.py b/src/ptbench/engine/saliency/interpretability.py
index 25ebd01fad4ed4d123d87159b921eaaa9ed15e5a..d56aa8469424f0b775e6bf0a3a807d078db2be73 100644
--- a/src/ptbench/engine/saliency/interpretability.py
+++ b/src/ptbench/engine/saliency/interpretability.py
@@ -6,10 +6,10 @@
 import logging
 import pathlib
 import typing

-import cv2
 import lightning.pytorch
 import numpy
 import numpy.typing
+import skimage.measure
 import torch
 import torchvision.ops
@@ -69,14 +69,8 @@ def _ordered_connected_components(
     if not numpy.any(thresholded_mask):
         return []

-    # opencv implementation:
-    n, labelled = cv2.connectedComponents(thresholded_mask, connectivity=8)
-    retval = [labelled == k for k in range(1, n)]
-
-    # scikit-image implementation
-    # import skimage.measure
-    # labelled, n = skimage.measure.label(thresholded_mask, return_num=True)
-    # retval = [labelled == k for k in range(1, n+1)]
+    labelled, n = skimage.measure.label(thresholded_mask, return_num=True)  # type: ignore
+    retval = [labelled == k for k in range(1, n + 1)]

     return sorted(retval, key=lambda x: x.sum(), reverse=True)

@@ -138,7 +132,7 @@ def _compute_max_iou_and_ioda(
     return iou, ioda


-def get_largest_bounding_boxes(
+def _get_largest_bounding_boxes(
     saliency_map: typing.Sequence[typing.Sequence[float]]
     | numpy.typing.NDArray[numpy.double],
     n: int,
@@ -269,7 +263,7 @@ def _compute_proportional_energy(
 def _process_sample(
     gt_bboxes: BoundingBoxes,
     saliency_map: numpy.typing.NDArray[numpy.double],
-) -> tuple[float, float, float, float, tuple[int, int, int, int]]:
+) -> tuple[float, float]:
     """Calculates the metrics for a single sample.

     Parameters
@@ -289,13 +283,13 @@ def _process_sample(
         * Largest detected bounding box
     """

-    largest_bbox = get_largest_bounding_boxes(saliency_map, n=1, threshold=0.2)
-    detected_box = (
-        largest_bbox[0] if largest_bbox else BoundingBox(-1, 0, 0, 0, 0)
-    )
-
-    # Calculate localization metrics
-    iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_bboxes)
+    # largest_bbox = _get_largest_bounding_boxes(saliency_map, n=1, threshold=0.2)
+    # detected_box = (
+    #     largest_bbox[0] if largest_bbox else BoundingBox(-1, 0, 0, 0, 0)
+    # )
+    #
+    # # Calculate localization metrics
+    # iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_bboxes)

     # The binary_mask will be ON/True where the gt boxes are located
     binary_mask = numpy.zeros_like(saliency_map, dtype=numpy.bool_)
@@ -306,16 +300,16 @@ def _process_sample(
         ] = True

     return (
-        iou,
-        ioda,
+        # iou,
+        # ioda,
         _compute_proportional_energy(saliency_map, binary_mask),
         _compute_avg_saliency_focus(saliency_map, binary_mask),
-        (
-            detected_box.xmin,
-            detected_box.ymin,
-            detected_box.width,
-            detected_box.height,
-        ),
+        # (
+        #     detected_box.xmin,
+        #     detected_box.ymin,
+        #     detected_box.width,
+        #     detected_box.height,
+        # ),
     )
@@ -348,11 +342,8 @@ def run(

         * Sample name (str)
         * Sample target class (int)
-        * IoU (float)
-        * IoDA (float)
         * Proportional energy (float)
         * Average saliency focus (float)
-        * Largest detected bounding box (x, y, width, height) (4 x int)
     """

     retval: dict[str, list[typing.Any]] = {}
diff --git a/src/ptbench/scripts/saliency/interpretability.py b/src/ptbench/scripts/saliency/interpretability.py
index 978bda7a841d0b25f6d552dd004aa367219f1512..6982b8700c66ea42e374ace4ff2dad4ed16ca8c9 100644
--- a/src/ptbench/scripts/saliency/interpretability.py
+++ b/src/ptbench/scripts/saliency/interpretability.py
@@ -94,7 +94,7 @@ def interpretability(

     .. note::

-       For obvious reasons, this evaluation is limited to databases that
+       For obvious reasons, this evaluation is limited to datasets that
        contain built-in annotations which corroborate classification.


@@ -102,11 +102,6 @@ def interpretability(
     that resembles the original datamodule, with added information containing
     the following measures, for each sample:

-    * IoU: The intersection of the (thresholded) saliency maps with
-      the annotation the most overlaps, over the union of both areas.
-    * IoDA: The intersection of the (thresholded) saliency maps with
-      the annotation that most overlaps, over area of (thresholded) saliency
-      maps.
     * Proportional Energy: A measure that compares (UNthresholded) saliency
       maps with annotations (based on [SCORECAM-2020]_). It estimates how much
       activation lies within the ground truth boxes compared to the total sum
@@ -115,21 +110,6 @@ def interpretability(
       boxes area is covered by the activations. It is similar to the
       proportional energy measure in the sense it does not need explicit
       thresholding.
-
-    .. important::
-
-       The thresholding algorithm used to evaluate IoU and IoDA measures is
-       based on the process done by the original CAM paper [GRADCAM-2015]_. It
-       keeps all points from the saliency map that are above the 20% of its
-       maximum value.
-
-       It then calculates a **single** bounding box for largest connected
-       component. This bounding box represents detected elements on the
-       original sample that corroborate the classification outcome.
-
-       IoU and IoDA are only evaluated for a single ground-truth bounding box
-       per sample (the one with the highest overlap). Any other bounding box
-       marked on the sample is ignored in the present implementation.
""" import json