From be4a633c661ff7f6b3099471384ea304ba16a2b1 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Fri, 15 Dec 2023 13:49:58 +0100
Subject: [PATCH] [engine.saliency.interpretability] Allow user to explicitly
 define target to be analysed

---
 .../engine/saliency/interpretability.py | 82 +++++++++++++++----
 1 file changed, 66 insertions(+), 16 deletions(-)

diff --git a/src/ptbench/engine/saliency/interpretability.py b/src/ptbench/engine/saliency/interpretability.py
index 9ecf9096..62c7512b 100644
--- a/src/ptbench/engine/saliency/interpretability.py
+++ b/src/ptbench/engine/saliency/interpretability.py
@@ -23,14 +23,16 @@ logger = logging.getLogger(__name__)
 def _ordered_connected_components(
     saliency_map: typing.Sequence[typing.Sequence[float]]
     | numpy.typing.NDArray[numpy.double],
+    threshold: float,
 ) -> list[numpy.typing.NDArray[numpy.bool_]]:
     """Calculates the largest connected components available on a saliency
     map and return those as individual masks.
 
     This implementation is based on [SCORECAM-2020]_:
 
-    1. Thresholding: The pixel values above 20% of max value are kept in the
-       original saliency map. Everything else is set to zero.
+    1. Thresholding: Pixel values above a fraction ``threshold`` of the
+       maximum are kept; the rest is set to zero. The value proposed in
+       [SCORECAM-2020]_ is 0.2; use it if unsure.
     2. The thresholded saliency map is transformed into a boolean array (ones
        are attributed to all elements above the threshold.
     3. We call :py:func:`skimage.metrics.label` to evaluate all connected
@@ -44,6 +46,10 @@ def _ordered_connected_components(
     saliency_map
         Input saliciency map whose connected components will be calculated
         from.
+    threshold
+        Relative threshold used to zero out parts of the original saliency
+        map. A value of 0.2 will zero all values in the saliency map that are
+        below 20% of the maximum value observed in that map.
 
 
     Returns
@@ -54,9 +60,10 @@ def _ordered_connected_components(
     """
 
     # thresholds like [SCORECAM-2020]_
-    thresholded_mask = (saliency_map >= (0.2 * numpy.max(saliency_map))).astype(
-        numpy.uint8
-    )
+    saliency_array = numpy.array(saliency_map)
+    thresholded_mask = (
+        saliency_array >= (threshold * saliency_array.max())
+    ).astype(numpy.uint8)
 
     # avoids an all zeroes mask being processed
     if not numpy.any(thresholded_mask):
@@ -131,6 +138,45 @@ def _compute_max_iou_and_ioda(
     return iou, ioda
 
 
+def get_largest_bounding_boxes(
+    saliency_map: typing.Sequence[typing.Sequence[float]]
+    | numpy.typing.NDArray[numpy.double],
+    n: int,
+    threshold: float = 0.2,
+) -> list[BoundingBox]:
+    """Returns the N largest connected components as bounding boxes in a
+    saliency map.
+
+    The returned values depend on the ``threshold`` applied as well as on
+    the saliency map itself. The number of objects found is also affected
+    by those parameters.
+
+    Parameters
+    ----------
+    saliency_map
+        Input saliency map from which the connected components will be
+        calculated.
+    n
+        The number of connected components to search for in the saliency map.
+        Connected components are then translated to bounding-box notation.
+    threshold
+        Relative threshold used to zero out parts of the original saliency
+        map. A value of 0.2 will zero all values in the saliency map that are
+        below 20% of the maximum value observed in that map.
+
+    Returns
+    -------
+        The (at most ``n``) largest connected components as bounding boxes.
+    """
+    retval: list[BoundingBox] = []
+
+    masks = _ordered_connected_components(saliency_map, threshold)
+    if masks:
+        retval += [_extract_bounding_box(k) for k in masks[:n]]
+
+    return retval
+
+
 def _compute_simultaneous_iou_and_ioda(
     detected_box: BoundingBox,
     gt_bboxes: BoundingBoxes,
@@ -217,7 +263,7 @@ def _compute_proportional_energy(
     if denominator == 0.0:
         return 0.0
 
-    return float(numpy.sum(saliency_map * gt_mask) / denominator)
+    return float(numpy.sum(saliency_map * gt_mask) / denominator)  # type: ignore
 
 
 def _process_sample(
@@ -243,11 +289,10 @@ def _process_sample(
     * Largest detected bounding box
     """
 
-    masks = _ordered_connected_components(saliency_map)
-    detected_box = BoundingBox(-1, 0, 0, 0, 0)
-    if masks:
-        # we get only the largest bounding box as of now
-        detected_box = _extract_bounding_box(masks[0])
+    largest_bbox = get_largest_bounding_boxes(saliency_map, n=1, threshold=0.2)
+    detected_box = (
+        largest_bbox[0] if largest_bbox else BoundingBox(-1, 0, 0, 0, 0)
+    )
 
     # Calculate localization metrics
     iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_bboxes)
@@ -271,6 +316,7 @@
 
 def run(
     input_folder: pathlib.Path,
+    target_label: int,
     datamodule: lightning.pytorch.LightningDataModule,
 ) -> dict[str, list[typing.Any]]:
     """Applies visualization techniques on input CXR, outputs images with
@@ -281,6 +327,9 @@ def run(
     input_folder
         Directory in which the saliency maps are stored for a specific
         visualization type.
+    target_label
+        The label to target for evaluating interpretability metrics. Samples
+        containing any other label are ignored.
     datamodule
         The lightning datamodule to iterate on.
 
@@ -318,6 +367,12 @@ def run(
         name = str(sample[1]["name"][0])
         label = int(sample[1]["label"].item())
 
+        if label != target_label:
+            # we add the entry for dataset completeness, but do not
+            # process it any further
+            retval[dataset_name].append([name, label])
+            continue
+
         # TODO: This is very specific to the TBX11k system for labelling
         # regions of interest. We need to abstract from this to support more
         # datasets and other ways to annotate.
@@ -325,11 +380,6 @@ def run(
         bboxes = sample[1].get(
             "bounding_boxes", BoundingBoxes()
         )
-        if label == 0:
-            # we add the entry for dataset completeness
-            retval[dataset_name].append([name, label])
-            continue
-
         if not bboxes:
             logger.warning(
                 f"Sample `{name}` does not contdain bounding-box information. "
-- 
GitLab
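Note on the new helper (not part of the patch itself): the sketch below reproduces, in a self-contained form, the logic that get_largest_bounding_boxes() implements after this change -- threshold the saliency map at a fraction of its maximum, label the connected components, and keep the ``n`` largest as bounding boxes. It assumes only numpy and scikit-image; the name largest_boxes_sketch and its plain-tuple boxes are illustrative stand-ins for ptbench's BoundingBox type, not the project's actual API.

import numpy
import skimage.measure


def largest_boxes_sketch(
    saliency_map: numpy.ndarray, n: int, threshold: float = 0.2
) -> list[tuple[int, int, int, int]]:
    """Returns up to ``n`` boxes as (min_row, min_col, max_row, max_col)."""
    # keep only pixels above ``threshold`` relative to the map's maximum
    mask = (saliency_map >= threshold * saliency_map.max()).astype(numpy.uint8)
    if not numpy.any(mask):
        # an all-zero mask has no connected components
        return []
    # label connected components, then order them by decreasing pixel count
    labelled = skimage.measure.label(mask)
    regions = sorted(
        skimage.measure.regionprops(labelled),
        key=lambda r: r.area,
        reverse=True,
    )
    return [r.bbox for r in regions[:n]]


if __name__ == "__main__":
    rng = numpy.random.default_rng(42)
    smap = rng.random((64, 64)) * 0.1
    smap[10:20, 30:40] = 1.0  # a hot region the threshold should isolate
    print(largest_boxes_sketch(smap, n=2))  # -> [(10, 30, 20, 40)]

The patched helper delegates to _ordered_connected_components() and _extract_bounding_box() instead, so component ordering and box conventions follow the module itself rather than regionprops.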