From ae43903177552858e1ef41029cdd6e115cd71c68 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Fri, 19 Jan 2024 13:50:47 +0100
Subject: [PATCH] Moved computation of binary mask to its own function

---
 .../engine/saliency/interpretability.py       | 42 +++++++++++++++----
 1 file changed, 35 insertions(+), 7 deletions(-)

diff --git a/src/ptbench/engine/saliency/interpretability.py b/src/ptbench/engine/saliency/interpretability.py
index d56aa846..1275ea5d 100644
--- a/src/ptbench/engine/saliency/interpretability.py
+++ b/src/ptbench/engine/saliency/interpretability.py
@@ -260,6 +260,40 @@ def _compute_proportional_energy(
     return float(numpy.sum(saliency_map * gt_mask) / denominator)  # type: ignore
 
 
+def _compute_binary_mask(
+    gt_bboxes: BoundingBoxes,
+    saliency_map: numpy.typing.NDArray[numpy.double],
+) -> numpy.typing.NDArray[numpy.bool_]:
+    """Computes a binary mask for the saliency map using BoundingBoxes.
+
+    The binary_mask will be ON/True where the gt boxes are located.
+
+    Parameters
+    ----------
+        gt_bboxes
+            Ground-truth bounding boxes in the format ``(x, y, width,
+            height)``.
+
+        saliency_map
+            A real-valued saliency-map that conveys regions used for
+            classification in the original sample.
+
+
+    Returns
+    -------
+        A numpy array of the same size as saliency_map with
+        the value False everywhere except at the positions inside
+        the bounding boxes, which will be True.
+    """
+    binary_mask = numpy.zeros_like(saliency_map, dtype=numpy.bool_)
+    for bbox in gt_bboxes:
+        binary_mask[
+            bbox.ymin : bbox.ymin + bbox.height,
+            bbox.xmin : bbox.xmin + bbox.width,
+        ] = True
+    return binary_mask
+
+
 def _process_sample(
     gt_bboxes: BoundingBoxes,
     saliency_map: numpy.typing.NDArray[numpy.double],
@@ -291,13 +325,7 @@ def _process_sample(
     # # Calculate localization metrics
     # iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_bboxes)
 
-    # The binary_mask will be ON/True where the gt boxes are located
-    binary_mask = numpy.zeros_like(saliency_map, dtype=numpy.bool_)
-    for bbox in gt_bboxes:
-        binary_mask[
-            bbox.ymin : bbox.ymin + bbox.height,
-            bbox.xmin : bbox.xmin + bbox.width,
-        ] = True
+    binary_mask = _compute_binary_mask(gt_bboxes, saliency_map)
 
     return (
         # iou,
-- 
GitLab