Commit bf95cf40 authored by Daniel CARRON

[saliency] Update interpretability script

parent 34d41324
Merge requests: !58 Make classification sample data into a dict and use tv_tensors, !46 Create common library
@@ -9,10 +9,9 @@ import typing
 import lightning.pytorch
 import numpy
 import numpy.typing
+from torchvision import tv_tensors
 from tqdm import tqdm
-from ...config.data.tbx11k.datamodule import BoundingBoxes
 
 logger = logging.getLogger(__name__)
 
 SaliencyMap: typing.TypeAlias = (
@@ -85,7 +84,7 @@ def _compute_proportional_energy(
 def _compute_binary_mask(
-    gt_bboxes: BoundingBoxes,
+    gt_bboxes: tv_tensors.BoundingBoxes,
     saliency_map: SaliencyMap,
 ) -> BinaryMask:
     """Compute a binary mask for the saliency map using BoundingBoxes.
@@ -110,16 +109,21 @@ def _compute_binary_mask(
     """
     binary_mask = numpy.zeros_like(saliency_map, dtype=numpy.bool_)
+    if gt_bboxes.format != tv_tensors.BoundingBoxFormat.XYXY:
+        raise ValueError(
+            f"Only boundingBoxes of format xyxy are supported. Got {gt_bboxes.format}."
+        )
     for bbox in gt_bboxes:
         binary_mask[
-            bbox.ymin : bbox.ymin + bbox.height,
-            bbox.xmin : bbox.xmin + bbox.width,
+            bbox.data[1] : bbox.data[1] + (bbox.data[3] - bbox.data[1]),
+            bbox.data[0] : bbox.data[0] + (bbox.data[2] - bbox.data[0]),
         ] = True
     return binary_mask
 
 def _process_sample(
-    gt_bboxes: BoundingBoxes,
+    gt_bboxes: tv_tensors.BoundingBoxes,
     saliency_map: SaliencyMap,
 ) -> tuple[float, float]:
     """Calculate the metrics for a single sample.
@@ -219,12 +223,12 @@ def run(
         # TODO: This is very specific to the TBX11k system for labelling
        # regions of interest. We need to abstract from this to support more
        # datasets and other ways to annotate.
-        bboxes: BoundingBoxes = sample[1].get(
+        bboxes: tv_tensors.BoundingBoxes = sample[0].get(
            "bounding_boxes",
-            BoundingBoxes(),
+            None,
        )
-        if not bboxes:
+        if bboxes is None:
            logger.warning(
                f"Sample `{name}` does not contain bounding-box information. "
                f"No localization metrics can be calculated in this case. "
...
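For orientation, here is a minimal, self-contained sketch of what the new mask computation does, assuming torchvision >= 0.16 (which ships `torchvision.tv_tensors`). In XYXY format each box row holds `xmin, ymin, xmax, ymax`, so the mask rows span `data[1]:data[3]` and the columns `data[0]:data[2]`. Variable names below are illustrative and not taken from the module.

```python
# Illustrative sketch only: mirrors the slicing used by the updated
# _compute_binary_mask, assuming torchvision >= 0.16 (tv_tensors module).
import numpy
from torchvision import tv_tensors

saliency_map = numpy.zeros((200, 200))  # stand-in for a real saliency map
boxes = tv_tensors.BoundingBoxes(
    data=[[50, 50, 100, 100]],  # one box: xmin, ymin, xmax, ymax
    format="XYXY",
    canvas_size=saliency_map.shape,
)

mask = numpy.zeros_like(saliency_map, dtype=numpy.bool_)
for bbox in boxes:
    # rows: ymin -> ymin + height, columns: xmin -> xmin + width
    mask[
        bbox.data[1] : bbox.data[1] + (bbox.data[3] - bbox.data[1]),
        bbox.data[0] : bbox.data[0] + (bbox.data[2] - bbox.data[0]),
    ] = True

assert mask.sum() == 50 * 50  # the 50x50 ground-truth region is selected
```

Checking the `format` attribute first, as the updated code does, prevents silently mis-slicing boxes supplied in XYWH or CXCYWH format.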
@@ -2,16 +2,13 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import numpy as np
-from mednet.classify.config.data.tbx11k.datamodule import (
-    BoundingBox,
-    BoundingBoxes,
-)
 from mednet.classify.engine.saliency.interpretability import (
     _compute_avg_saliency_focus,
     _compute_binary_mask,
     _compute_proportional_energy,
     _process_sample,
 )
+from torchvision import tv_tensors
 
 
 def test_compute_avg_saliency_focus():
@@ -19,12 +16,21 @@ def test_compute_avg_saliency_focus():
     grayscale_cams2 = np.full((512, 512), 0.5)
     grayscale_cams3 = np.zeros((256, 256))
     grayscale_cams3[50:75, 50:100] = 1
-    gt_box_dict = BoundingBox(1, 50, 50, 50, 50)
-    gt_boxes = BoundingBoxes([gt_box_dict])
+
+    bbox_data = [50, 50, 100, 100]
+    gt_boxes = tv_tensors.BoundingBoxes(
+        data=bbox_data, format="XYXY", canvas_size=grayscale_cams.shape
+    )
 
     binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams)
-    binary_mask2 = _compute_binary_mask(gt_boxes, grayscale_cams2)
-    binary_mask3 = _compute_binary_mask(gt_boxes, grayscale_cams3)
+
+    gt_boxes2 = tv_tensors.BoundingBoxes(
+        data=bbox_data, format="XYXY", canvas_size=grayscale_cams.shape
+    )
+    binary_mask2 = _compute_binary_mask(gt_boxes2, grayscale_cams2)
+
+    gt_boxes3 = tv_tensors.BoundingBoxes(
+        data=bbox_data, format="XYXY", canvas_size=grayscale_cams.shape
+    )
+    binary_mask3 = _compute_binary_mask(gt_boxes3, grayscale_cams3)
 
     avg_saliency_focus = _compute_avg_saliency_focus(
         grayscale_cams,
@@ -46,8 +52,12 @@ def test_compute_avg_saliency_focus():
 
 def test_compute_avg_saliency_focus_no_activations():
     grayscale_cams = np.zeros((200, 200))
-    gt_box_dict = BoundingBox(1, 50, 50, 50, 50)
-    gt_boxes = BoundingBoxes([gt_box_dict])
+
+    bbox_data = [50, 50, 100, 100]
+    gt_boxes = tv_tensors.BoundingBoxes(
+        data=bbox_data, format="XYXY", canvas_size=grayscale_cams.shape
+    )
 
     binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams)
     avg_saliency_focus = _compute_avg_saliency_focus(
@@ -60,8 +70,12 @@ def test_compute_avg_saliency_focus_no_activations():
 
 def test_compute_avg_saliency_focus_zero_gt_area():
     grayscale_cams = np.ones((200, 200))
-    gt_box_dict = BoundingBox(1, 50, 50, 0, 0)
-    gt_boxes = BoundingBoxes([gt_box_dict])
+
+    bbox_data = [50, 50, 50, 50]
+    gt_boxes = tv_tensors.BoundingBoxes(
+        data=bbox_data, format="XYXY", canvas_size=grayscale_cams.shape
+    )
 
     binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams)
     avg_saliency_focus = _compute_avg_saliency_focus(
@@ -77,12 +91,21 @@ def test_compute_proportional_energy():
     grayscale_cams2 = np.full((512, 512), 0.5)
     grayscale_cams3 = np.zeros((512, 512))
     grayscale_cams3[100:200, 100:200] = 1
-    gt_box_dict = BoundingBox(1, 50, 50, 100, 100)
-    gt_boxes = BoundingBoxes([gt_box_dict])
+
+    bbox_data = [50, 50, 150, 150]
+    gt_boxes = tv_tensors.BoundingBoxes(
+        data=bbox_data, format="XYXY", canvas_size=grayscale_cams.shape
+    )
 
     binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams)
-    binary_mask2 = _compute_binary_mask(gt_boxes, grayscale_cams2)
-    binary_mask3 = _compute_binary_mask(gt_boxes, grayscale_cams3)
+
+    gt_boxes2 = tv_tensors.BoundingBoxes(
+        data=bbox_data, format="XYXY", canvas_size=grayscale_cams2.shape
+    )
+    binary_mask2 = _compute_binary_mask(gt_boxes2, grayscale_cams2)
+
+    gt_boxes3 = tv_tensors.BoundingBoxes(
+        data=bbox_data, format="XYXY", canvas_size=grayscale_cams3.shape
+    )
+    binary_mask3 = _compute_binary_mask(gt_boxes3, grayscale_cams3)
 
     proportional_energy = _compute_proportional_energy(
         grayscale_cams,
@@ -104,8 +127,12 @@ def test_compute_proportional_energy():
 
 def test_compute_proportional_energy_no_activations():
     grayscale_cams = np.zeros((200, 200))
-    gt_box_dict = BoundingBox(1, 50, 50, 50, 50)
-    gt_boxes = BoundingBoxes([gt_box_dict])
+
+    bbox_data = [50, 50, 150, 150]
+    gt_boxes = tv_tensors.BoundingBoxes(
+        data=bbox_data, format="XYXY", canvas_size=grayscale_cams.shape
+    )
 
     binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams)
     proportional_energy = _compute_proportional_energy(
@@ -118,8 +145,12 @@ def test_compute_proportional_energy_no_activations():
 
 def test_compute_proportional_energy_no_gt_box():
     grayscale_cams = np.ones((200, 200))
-    gt_box_dict = BoundingBox(1, 0, 0, 0, 0)
-    gt_boxes = BoundingBoxes([gt_box_dict])
+
+    bbox_data = [0, 0, 0, 0]
+    gt_boxes = tv_tensors.BoundingBoxes(
+        data=bbox_data, format="XYXY", canvas_size=grayscale_cams.shape
+    )
 
     binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams)
     proportional_energy = _compute_proportional_energy(
@@ -132,8 +163,12 @@ def test_compute_proportional_energy_no_gt_box():
 
 def test_process_sample():
     grayscale_cams = np.ones((200, 200))
-    gt_box_dict = BoundingBox(1, 50, 50, 0, 0)
-    gt_boxes = BoundingBoxes([gt_box_dict])
+
+    bbox_data = [50, 50, 50, 50]
+    gt_boxes = tv_tensors.BoundingBoxes(
+        data=bbox_data, format="XYXY", canvas_size=grayscale_cams.shape
+    )
 
     proportional_energy, avg_saliency_focus = _process_sample(
         gt_boxes,
...
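For context, proportional energy is commonly understood as the fraction of total saliency that falls inside the ground-truth region, and average saliency focus as the mean saliency over that region. The sketch below shows these textbook definitions only; it is an assumption for orientation and not necessarily identical to the private `_compute_proportional_energy` and `_compute_avg_saliency_focus` helpers exercised above.

```python
# Hypothetical reference definitions, for orientation only; the actual
# mednet helpers may handle edge cases (empty masks, zero activations)
# differently.
import numpy


def proportional_energy(saliency_map: numpy.ndarray, mask: numpy.ndarray) -> float:
    """Fraction of the total saliency energy inside the ground-truth mask."""
    total = saliency_map.sum()
    if total == 0:
        return 0.0  # no activations anywhere in the map
    return float(saliency_map[mask].sum() / total)


def avg_saliency_focus(saliency_map: numpy.ndarray, mask: numpy.ndarray) -> float:
    """Mean saliency value over the ground-truth region."""
    area = mask.sum()
    if area == 0:
        return 0.0  # degenerate (zero-area) bounding box
    return float(saliency_map[mask].sum() / area)
```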