diff --git a/src/mednet/engine/saliency/interpretability.py b/src/mednet/engine/saliency/interpretability.py index a24c5f3fb8f2c3e74e471bcd49c478d92767e95c..65bd3614cf5d18dd104ca385357f0a4b3ed26d1b 100644 --- a/src/mednet/engine/saliency/interpretability.py +++ b/src/mednet/engine/saliency/interpretability.py @@ -19,12 +19,16 @@ from ...config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes logger = logging.getLogger(__name__) +SaliencyMap: typing.TypeAlias = ( + typing.Sequence[typing.Sequence[float]] | numpy.typing.NDArray[numpy.double] +) +BinaryMask: typing.TypeAlias = numpy.typing.NDArray[numpy.bool_] + def _ordered_connected_components( - saliency_map: typing.Sequence[typing.Sequence[float]] - | numpy.typing.NDArray[numpy.double], + saliency_map: SaliencyMap, threshold: float, -) -> list[numpy.typing.NDArray[numpy.bool_]]: +) -> list[BinaryMask]: """Calculates the largest connected components available on a saliency map and return those as individual masks. @@ -76,7 +80,7 @@ def _ordered_connected_components( def _extract_bounding_box( - mask: numpy.typing.NDArray[numpy.bool_], + mask: BinaryMask, ) -> BoundingBox: """Defines a bounding box surrounding a connected component mask. @@ -147,8 +151,7 @@ def _compute_max_iou_and_ioda( def _get_largest_bounding_boxes( - saliency_map: typing.Sequence[typing.Sequence[float]] - | numpy.typing.NDArray[numpy.double], + saliency_map: SaliencyMap, n: int, threshold: float = 0.2, ) -> list[BoundingBox]: @@ -227,9 +230,37 @@ def _compute_simultaneous_iou_and_ioda( return float(iou), float(ioda) +def _compute_iou_ioda_from_largest_bbox( + gt_bboxes: BoundingBoxes, + saliency_map: SaliencyMap, +) -> tuple[float, float]: + """Calculates the metrics for a single sample. + + Parameters + ---------- + gt_bboxes + A list of ground-truth bounding boxes. + saliency_map + A real-valued saliency-map that conveys regions used for + classification in the original sample. + + + Returns + ------- + A tuple containing the iou and ioda for the largest bounding box. + """ + + largest_bbox = _get_largest_bounding_boxes(saliency_map, n=1, threshold=0.2) + detected_box = ( + largest_bbox[0] if largest_bbox else BoundingBox(-1, 0, 0, 0, 0) + ) + iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_bboxes) + return (iou, ioda) + + def _compute_avg_saliency_focus( - saliency_map: numpy.typing.NDArray[numpy.double], - gt_mask: numpy.typing.NDArray[numpy.bool_], + saliency_map: SaliencyMap, + gt_mask: BinaryMask, ) -> float: """Integrates the saliency map over the ground-truth boxes and normalizes by total bounding-box area. @@ -262,8 +293,8 @@ def _compute_avg_saliency_focus( def _compute_proportional_energy( - saliency_map: numpy.typing.NDArray[numpy.double], - gt_mask: numpy.typing.NDArray[numpy.bool_], + saliency_map: SaliencyMap, + gt_mask: BinaryMask, ) -> float: """Calculates how much activation lies within the ground truth boxes compared to the total sum of the activations (integral). @@ -293,8 +324,8 @@ def _compute_proportional_energy( def _compute_binary_mask( gt_bboxes: BoundingBoxes, - saliency_map: numpy.typing.NDArray[numpy.double], -) -> numpy.typing.NDArray[numpy.bool_]: + saliency_map: SaliencyMap, +) -> BinaryMask: """Computes a binary mask for the saliency map using BoundingBoxes. The binary_mask will be ON/True where the gt boxes are located. @@ -329,7 +360,7 @@ def _compute_binary_mask( def _process_sample( gt_bboxes: BoundingBoxes, - saliency_map: numpy.typing.NDArray[numpy.double], + saliency_map: SaliencyMap, ) -> tuple[float, float]: """Calculates the metrics for a single sample. @@ -346,34 +377,15 @@ def _process_sample( ------- A tuple containing the following values: - * IoU - * IoDA * Proportional energy * Average saliency focus - * Largest detected bounding box """ - # largest_bbox = _get_largest_bounding_boxes(saliency_map, n=1, threshold=0.2) - # detected_box = ( - # largest_bbox[0] if largest_bbox else BoundingBox(-1, 0, 0, 0, 0) - # ) - # - # # Calculate localization metrics - # iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_bboxes) - binary_mask = _compute_binary_mask(gt_bboxes, saliency_map) return ( - # iou, - # ioda, _compute_proportional_energy(saliency_map, binary_mask), _compute_avg_saliency_focus(saliency_map, binary_mask), - # ( - # detected_box.xmin, - # detected_box.ymin, - # detected_box.width, - # detected_box.height, - # ), )