From be4a633c661ff7f6b3099471384ea304ba16a2b1 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Fri, 15 Dec 2023 13:49:58 +0100
Subject: [PATCH] [engine.saliency.interpretability] Allow user to explicitly
 define target to be analysed

---
 .../engine/saliency/interpretability.py | 82 +++++++++++++++----
 1 file changed, 66 insertions(+), 16 deletions(-)

diff --git a/src/ptbench/engine/saliency/interpretability.py b/src/ptbench/engine/saliency/interpretability.py
index 9ecf9096..62c7512b 100644
--- a/src/ptbench/engine/saliency/interpretability.py
+++ b/src/ptbench/engine/saliency/interpretability.py
@@ -23,14 +23,16 @@ logger = logging.getLogger(__name__)
 def _ordered_connected_components(
     saliency_map: typing.Sequence[typing.Sequence[float]]
     | numpy.typing.NDArray[numpy.double],
+    threshold: float,
 ) -> list[numpy.typing.NDArray[numpy.bool_]]:
     """Calculates the largest connected components available on a saliency
     map and return those as individual masks.
 
     This implementation is based on [SCORECAM-2020]_:
 
-    1. Thresholding: The pixel values above 20% of max value are kept in the
-       original saliency map. Everything else is set to zero.
+    1. Thresholding: Pixel values above a fraction ``threshold`` of the
+       maximum are kept; the rest is set to zero. The value proposed in
+       [SCORECAM-2020]_ is 0.2; use it if unsure.
     2. The thresholded saliency map is transformed into a boolean array (ones
        are attributed to all elements above the threshold.
     3. We call :py:func:`skimage.metrics.label` to evaluate all connected
@@ -44,6 +46,10 @@ def _ordered_connected_components(
     saliency_map
         Input saliciency map whose connected components will be calculated
         from.
+    threshold
+        Relative threshold used to zero out parts of the original saliency
+        map. A value of 0.2 will zero all values in the saliency map that are
+        below 20% of the maximum value observed in that map.
 
 
     Returns
@@ -54,9 +60,10 @@ def _ordered_connected_components(
     """
 
     # thresholds like [SCORECAM-2020]_
-    thresholded_mask = (saliency_map >= (0.2 * numpy.max(saliency_map))).astype(
-        numpy.uint8
-    )
+    saliency_array = numpy.array(saliency_map)
+    thresholded_mask = (
+        saliency_array >= (threshold * saliency_array.max())
+    ).astype(numpy.uint8)
 
     # avoids an all zeroes mask being processed
     if not numpy.any(thresholded_mask):
@@ -131,6 +138,45 @@ def _compute_max_iou_and_ioda(
     return iou, ioda
 
 
+def get_largest_bounding_boxes(
+    saliency_map: typing.Sequence[typing.Sequence[float]]
+    | numpy.typing.NDArray[numpy.double],
+    n: int,
+    threshold: float = 0.2,
+) -> list[BoundingBox]:
+    """Returns the N largest connected components as bounding boxes in a
+    saliency map.
+
+    The returned values depend on the ``threshold`` applied as well as on
+    the saliency map itself. The number of objects found is also affected
+    by those parameters.
+
+    Parameters
+    ----------
+    saliency_map
+        Input saliency map from which the connected components will be
+        calculated.
+    n
+        The number of connected components to search for in the saliency map.
+        Connected components are then translated to bounding-box notation.
+    threshold
+        Relative threshold used to zero out parts of the original saliency
+        map. A value of 0.2 will zero all values in the saliency map that are
+        below 20% of the maximum value observed in that map.
+
+    Returns
+    -------
+        The (at most ``n``) largest connected components as bounding boxes.
+    """
+    retval: list[BoundingBox] = []
+
+    masks = _ordered_connected_components(saliency_map, threshold)
+    if masks:
+        retval += [_extract_bounding_box(k) for k in masks[:n]]
+
+    return retval
+
+
 def _compute_simultaneous_iou_and_ioda(
     detected_box: BoundingBox,
     gt_bboxes: BoundingBoxes,
@@ -217,7 +263,7 @@ def _compute_proportional_energy(
     if denominator == 0.0:
         return 0.0
 
-    return float(numpy.sum(saliency_map * gt_mask) / denominator)
+    return float(numpy.sum(saliency_map * gt_mask) / denominator)  # type: ignore
 
 
 def _process_sample(
@@ -243,11 +289,10 @@ def _process_sample(
     * Largest detected bounding box
     """
 
-    masks = _ordered_connected_components(saliency_map)
-    detected_box = BoundingBox(-1, 0, 0, 0, 0)
-    if masks:
-        # we get only the largest bounding box as of now
-        detected_box = _extract_bounding_box(masks[0])
+    largest_bbox = get_largest_bounding_boxes(saliency_map, n=1, threshold=0.2)
+    detected_box = (
+        largest_bbox[0] if largest_bbox else BoundingBox(-1, 0, 0, 0, 0)
+    )
 
     # Calculate localization metrics
     iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_bboxes)
@@ -271,6 +316,7 @@
 
 def run(
     input_folder: pathlib.Path,
+    target_label: int,
     datamodule: lightning.pytorch.LightningDataModule,
 ) -> dict[str, list[typing.Any]]:
     """Applies visualization techniques on input CXR, outputs images with
@@ -281,6 +327,9 @@ def run(
     input_folder
         Directory in which the saliency maps are stored for a specific
         visualization type.
+    target_label
+        The label to target for evaluating interpretability metrics. Samples
+        containing any other label are ignored.
     datamodule
         The lightning datamodule to iterate on.
 
@@ -318,6 +367,12 @@ def run(
         name = str(sample[1]["name"][0])
         label = int(sample[1]["label"].item())
 
+        if label != target_label:
+            # we add the entry for dataset completeness, but do not
+            # process it any further
+            retval[dataset_name].append([name, label])
+            continue
+
         # TODO: This is very specific to the TBX11k system for labelling
         # regions of interest. We need to abstract from this to support more
         # datasets and other ways to annotate.
@@ -325,11 +380,6 @@ def run(
         bboxes = sample[1].get(
             "bounding_boxes", BoundingBoxes()
         )
-        if label == 0:
-            # we add the entry for dataset completeness
-            retval[dataset_name].append([name, label])
-            continue
-
         if not bboxes:
             logger.warning(
                 f"Sample `{name}` does not contdain bounding-box information. "
-- 
GitLab
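Note on the new helper (not part of the patch itself): the sketch below reproduces, in a self-contained form, the logic that get_largest_bounding_boxes() implements after this change -- threshold the saliency map at a fraction of its maximum, label the connected components, and keep the ``n`` largest as bounding boxes. It assumes only numpy and scikit-image; the name largest_boxes_sketch and its plain-tuple boxes are illustrative stand-ins for ptbench's BoundingBox type, not the project's actual API.

import numpy
import skimage.measure


def largest_boxes_sketch(
    saliency_map: numpy.ndarray, n: int, threshold: float = 0.2
) -> list[tuple[int, int, int, int]]:
    """Returns up to ``n`` boxes as (min_row, min_col, max_row, max_col)."""
    # keep only pixels above ``threshold`` relative to the map's maximum
    mask = (saliency_map >= threshold * saliency_map.max()).astype(numpy.uint8)
    if not numpy.any(mask):
        # an all-zero mask has no connected components
        return []
    # label connected components, then order them by decreasing pixel count
    labelled = skimage.measure.label(mask)
    regions = sorted(
        skimage.measure.regionprops(labelled),
        key=lambda r: r.area,
        reverse=True,
    )
    return [r.bbox for r in regions[:n]]


if __name__ == "__main__":
    rng = numpy.random.default_rng(42)
    smap = rng.random((64, 64)) * 0.1
    smap[10:20, 30:40] = 1.0  # a hot region the threshold should isolate
    print(largest_boxes_sketch(smap, n=2))  # -> [(10, 30, 20, 40)]

The patched helper delegates to _ordered_connected_components() and _extract_bounding_box() instead, so component ordering and box conventions follow the module itself rather than regionprops.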