From ae43903177552858e1ef41029cdd6e115cd71c68 Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Fri, 19 Jan 2024 13:50:47 +0100 Subject: [PATCH] Moved computation of binary mask to its own function --- .../engine/saliency/interpretability.py | 42 +++++++++++++++---- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/src/ptbench/engine/saliency/interpretability.py b/src/ptbench/engine/saliency/interpretability.py index d56aa846..1275ea5d 100644 --- a/src/ptbench/engine/saliency/interpretability.py +++ b/src/ptbench/engine/saliency/interpretability.py @@ -260,6 +260,40 @@ def _compute_proportional_energy( return float(numpy.sum(saliency_map * gt_mask) / denominator) # type: ignore +def _compute_binary_mask( + gt_bboxes: BoundingBoxes, + saliency_map: numpy.typing.NDArray[numpy.double], +) -> numpy.typing.NDArray[numpy.bool_]: + """Computes a binary mask for the saliency map using BoundingBoxes. + + The binary_mask will be ON/True where the gt boxes are located. + + Parameters + ---------- + gt_bboxes + Ground-truth bounding boxes in the format ``(x, y, width, + height)``. + + saliency_map + A real-valued saliency-map that conveys regions used for + classification in the original sample. + + + Returns + ------- + A numpy array of the same size as saliency_map with + the value False everywhere except at the positions inside + the bounding boxes, which will be True. + """ + binary_mask = numpy.zeros_like(saliency_map, dtype=numpy.bool_) + for bbox in gt_bboxes: + binary_mask[ + bbox.ymin : bbox.ymin + bbox.height, + bbox.xmin : bbox.xmin + bbox.width, + ] = True + return binary_mask + + def _process_sample( gt_bboxes: BoundingBoxes, saliency_map: numpy.typing.NDArray[numpy.double], @@ -291,13 +325,7 @@ def _process_sample( # # Calculate localization metrics # iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_bboxes) - # The binary_mask will be ON/True where the gt boxes are located - binary_mask = numpy.zeros_like(saliency_map, dtype=numpy.bool_) - for bbox in gt_bboxes: - binary_mask[ - bbox.ymin : bbox.ymin + bbox.height, - bbox.xmin : bbox.xmin + bbox.width, - ] = True + binary_mask = _compute_binary_mask(gt_bboxes, saliency_map) return ( # iou, -- GitLab