Skip to content
Snippets Groups Projects
Commit 14de2dc6 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

[doc] Cleaned up interpretability _process_sample()

parent 9e60e08c
No related branches found
No related tags found
1 merge request!15Update documentation
......@@ -19,12 +19,16 @@ from ...config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes
logger = logging.getLogger(__name__)
SaliencyMap: typing.TypeAlias = (
typing.Sequence[typing.Sequence[float]] | numpy.typing.NDArray[numpy.double]
)
BinaryMask: typing.TypeAlias = numpy.typing.NDArray[numpy.bool_]
def _ordered_connected_components(
saliency_map: typing.Sequence[typing.Sequence[float]]
| numpy.typing.NDArray[numpy.double],
saliency_map: SaliencyMap,
threshold: float,
) -> list[numpy.typing.NDArray[numpy.bool_]]:
) -> list[BinaryMask]:
"""Calculates the largest connected components available on a saliency map
and return those as individual masks.
......@@ -76,7 +80,7 @@ def _ordered_connected_components(
def _extract_bounding_box(
mask: numpy.typing.NDArray[numpy.bool_],
mask: BinaryMask,
) -> BoundingBox:
"""Defines a bounding box surrounding a connected component mask.
......@@ -147,8 +151,7 @@ def _compute_max_iou_and_ioda(
def _get_largest_bounding_boxes(
saliency_map: typing.Sequence[typing.Sequence[float]]
| numpy.typing.NDArray[numpy.double],
saliency_map: SaliencyMap,
n: int,
threshold: float = 0.2,
) -> list[BoundingBox]:
......@@ -227,9 +230,37 @@ def _compute_simultaneous_iou_and_ioda(
return float(iou), float(ioda)
def _compute_iou_ioda_from_largest_bbox(
    gt_bboxes: BoundingBoxes,
    saliency_map: SaliencyMap,
    threshold: float = 0.2,
) -> tuple[float, float]:
    """Calculates IoU and IoDA for the largest box detected on a saliency
    map.

    The single largest bounding box is requested from
    :py:func:`_get_largest_bounding_boxes` and then scored against the
    ground-truth boxes with :py:func:`_compute_max_iou_and_ioda`.

    Parameters
    ----------
    gt_bboxes
        A list of ground-truth bounding boxes.
    saliency_map
        A real-valued saliency-map that conveys regions used for
        classification in the original sample.
    threshold
        Value forwarded as ``threshold`` to
        :py:func:`_get_largest_bounding_boxes`.  Defaults to 0.2, the value
        that was previously hard-coded here.

    Returns
    -------
        A tuple containing the iou and ioda for the largest bounding box.
    """

    candidates = _get_largest_bounding_boxes(
        saliency_map, n=1, threshold=threshold
    )

    if candidates:
        detected_box = candidates[0]
    else:
        # No box was returned for this map: fall back to a placeholder box
        # built from (-1, 0, 0, 0, 0) so the scoring call below still runs.
        # NOTE(review): presumably label=-1 with a zero-sized box -- confirm
        # against the BoundingBox definition.
        detected_box = BoundingBox(-1, 0, 0, 0, 0)

    iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_bboxes)
    return (iou, ioda)
def _compute_avg_saliency_focus(
saliency_map: numpy.typing.NDArray[numpy.double],
gt_mask: numpy.typing.NDArray[numpy.bool_],
saliency_map: SaliencyMap,
gt_mask: BinaryMask,
) -> float:
"""Integrates the saliency map over the ground-truth boxes and normalizes
by total bounding-box area.
......@@ -262,8 +293,8 @@ def _compute_avg_saliency_focus(
def _compute_proportional_energy(
saliency_map: numpy.typing.NDArray[numpy.double],
gt_mask: numpy.typing.NDArray[numpy.bool_],
saliency_map: SaliencyMap,
gt_mask: BinaryMask,
) -> float:
"""Calculates how much activation lies within the ground truth boxes
compared to the total sum of the activations (integral).
......@@ -293,8 +324,8 @@ def _compute_proportional_energy(
def _compute_binary_mask(
gt_bboxes: BoundingBoxes,
saliency_map: numpy.typing.NDArray[numpy.double],
) -> numpy.typing.NDArray[numpy.bool_]:
saliency_map: SaliencyMap,
) -> BinaryMask:
"""Computes a binary mask for the saliency map using BoundingBoxes.
The binary_mask will be ON/True where the gt boxes are located.
......@@ -329,7 +360,7 @@ def _compute_binary_mask(
def _process_sample(
gt_bboxes: BoundingBoxes,
saliency_map: numpy.typing.NDArray[numpy.double],
saliency_map: SaliencyMap,
) -> tuple[float, float]:
"""Calculates the metrics for a single sample.
......@@ -346,34 +377,15 @@ def _process_sample(
-------
A tuple containing the following values:
* IoU
* IoDA
* Proportional energy
* Average saliency focus
* Largest detected bounding box
"""
# largest_bbox = _get_largest_bounding_boxes(saliency_map, n=1, threshold=0.2)
# detected_box = (
# largest_bbox[0] if largest_bbox else BoundingBox(-1, 0, 0, 0, 0)
# )
#
# # Calculate localization metrics
# iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_bboxes)
binary_mask = _compute_binary_mask(gt_bboxes, saliency_map)
return (
# iou,
# ioda,
_compute_proportional_energy(saliency_map, binary_mask),
_compute_avg_saliency_focus(saliency_map, binary_mask),
# (
# detected_box.xmin,
# detected_box.ymin,
# detected_box.width,
# detected_box.height,
# ),
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment