# test_saliencymap_interpretability.py
# Authored by André Anjos.
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import numpy as np
from mednet.classify.config.data.tbx11k.datamodule import (
BoundingBox,
BoundingBoxes,
)
from mednet.classify.engine.saliency.interpretability import (
_compute_avg_saliency_focus,
_compute_binary_mask,
_compute_max_iou_and_ioda,
_compute_proportional_energy,
_compute_simultaneous_iou_and_ioda,
_process_sample,
)
def test_compute_max_iou_and_ioda():
    """Check that the best-matching ground-truth box determines IoU/IoDA.

    Assuming ``BoundingBox(label, xmin, ymin, width, height)`` (the layout the
    expected value below is consistent with), both ground-truth boxes lie
    inside the 100x100 detection.  For each of them the intersection is the
    ground-truth area and the union is the detected area.  The larger 60x60
    box therefore gives the maximum, 3600 / 10000 = 0.36, for both metrics.
    """
    detected_box = BoundingBox(-1, 10, 10, 100, 100)
    gt_box_dict = BoundingBox(1, 50, 50, 50, 50)
    gt_box_dict2 = BoundingBox(1, 20, 20, 60, 60)
    gt_boxes = BoundingBoxes([gt_box_dict, gt_box_dict2])

    iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_boxes)

    expected_iou = 0.36
    expected_ioda = 0.36
    # Both metrics are floating-point ratios: compare with a tolerance rather
    # than ``==`` so the test does not depend on the exact order in which the
    # implementation performs its arithmetic.
    assert np.isclose(iou, expected_iou)
    assert np.isclose(ioda, expected_ioda)
def test_compute_max_iou_and_ioda_zero_detected_area():
    """A detection with no area cannot overlap any ground-truth box."""
    empty_detection = BoundingBox(-1, 10, 10, 0, 0)
    reference = BoundingBoxes([BoundingBox(1, 50, 50, 50, 50)])

    iou, ioda = _compute_max_iou_and_ioda(empty_detection, reference)

    # Both overlap ratios are expected to collapse to zero.
    assert (iou, ioda) == (0, 0)
def test_compute_max_iou_and_ioda_zero_gt_area():
    """A zero-area ground-truth box yields zero IoU and zero IoDA."""
    detected_box = BoundingBox(-1, 10, 10, 100, 100)
    gt_box_dict = BoundingBox(1, 50, 50, 0, 0)
    gt_boxes = BoundingBoxes([gt_box_dict])
    iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_boxes)
    # Should be zero: a ground-truth box is present, but it is degenerate
    # (zero area), so it cannot contribute any overlap with the detection.
    assert iou == 0
    assert ioda == 0
def test_compute_max_iou_and_ioda_zero_intersection():
    """Disjoint detection and ground-truth boxes give zero IoU and IoDA."""
    detection = BoundingBox(-1, 10, 10, 100, 100)
    # Small box near the origin.  Presumably (xmin, ymin, width, height), so it
    # ends before the detection starts at (10, 10); verify against datamodule.
    far_away = BoundingBoxes([BoundingBox(1, 0, 0, 5, 5)])

    iou, ioda = _compute_max_iou_and_ioda(detection, far_away)

    assert (iou, ioda) == (0, 0)
def test_compute_simultaneous_iou_and_ioda():
    """Check IoU/IoDA computed against all ground-truth boxes at once.

    NOTE(review): the expected 0.34 equals (50*50 + 30*30) / (100*100), i.e.
    the two ground-truth areas added together.  Assuming
    ``BoundingBox(label, xmin, ymin, width, height)``, the second box lies
    entirely inside the first, so overlapping ground-truth regions appear
    to be counted twice rather than merged.  Confirm this is intended.
    """
    detected_box = BoundingBox(-1, 10, 10, 100, 100)
    gt_box_dict = BoundingBox(1, 50, 50, 50, 50)
    gt_box_dict2 = BoundingBox(1, 70, 70, 30, 30)
    gt_boxes = BoundingBoxes([gt_box_dict, gt_box_dict2])

    iou, ioda = _compute_simultaneous_iou_and_ioda(detected_box, gt_boxes)

    # Floating-point ratios: compare with a tolerance, not ``==``.
    assert np.isclose(iou, 0.34)
    assert np.isclose(ioda, 0.34)
def test_compute_avg_saliency_focus():
    """Average saliency focus for one ground-truth box on three maps.

    The same box is checked against:

    * a 200x200 map of ones (expected focus 1);
    * a 512x512 uniform map of 0.5 (expected focus 0.5);
    * a 256x256 map that is 1 only on rows 50:75, columns 50:100
      (expected focus 0.5).
    """
    boxes = BoundingBoxes([BoundingBox(1, 50, 50, 50, 50)])

    all_ones = np.ones((200, 200))
    all_half = np.full((512, 512), 0.5)
    partial = np.zeros((256, 256))
    partial[50:75, 50:100] = 1

    cases = (
        (all_ones, 1),
        (all_half, 0.5),
        (partial, 0.5),
    )
    for cams, expected in cases:
        mask = _compute_binary_mask(boxes, cams)
        assert _compute_avg_saliency_focus(cams, mask) == expected
def test_compute_avg_saliency_focus_no_activations():
    """An all-zero saliency map has zero average focus inside the box."""
    blank = np.zeros((200, 200))
    boxes = BoundingBoxes([BoundingBox(1, 50, 50, 50, 50)])

    focus = _compute_avg_saliency_focus(
        blank,
        _compute_binary_mask(boxes, blank),
    )

    assert focus == 0
def test_compute_avg_saliency_focus_zero_gt_area():
    """A zero-area ground-truth box gives zero focus, even on a saturated map."""
    saturated = np.ones((200, 200))
    degenerate = BoundingBoxes([BoundingBox(1, 50, 50, 0, 0)])

    mask = _compute_binary_mask(degenerate, saturated)
    focus = _compute_avg_saliency_focus(saturated, mask)

    # Nothing lies inside the box, so the test expects 0 rather than NaN.
    assert focus == 0
def test_compute_proportional_energy():
    """Proportional energy of one ground-truth box on three saliency maps.

    The metric is presumably the share of the map's total saliency energy
    that falls inside the box; the expected values are consistent with that
    reading.  Assuming ``BoundingBox(label, xmin, ymin, width, height)``, the
    box covers 100x100 = 10000 pixels, so:

    * 200x200 map of ones: 10000 / (200 * 200) = 0.25;
    * 512x512 uniform map of 0.5: 10000 / (512 * 512) (about 0.0381);
    * 512x512 map equal to 1 on [100:200, 100:200]: only the 50x50 part
      overlapping the box counts, 2500 / 10000 = 0.25.
    """
    grayscale_cams = np.ones((200, 200))
    grayscale_cams2 = np.full((512, 512), 0.5)
    grayscale_cams3 = np.zeros((512, 512))
    grayscale_cams3[100:200, 100:200] = 1

    gt_box_dict = BoundingBox(1, 50, 50, 100, 100)
    gt_boxes = BoundingBoxes([gt_box_dict])

    binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams)
    binary_mask2 = _compute_binary_mask(gt_boxes, grayscale_cams2)
    binary_mask3 = _compute_binary_mask(gt_boxes, grayscale_cams3)

    proportional_energy = _compute_proportional_energy(
        grayscale_cams,
        binary_mask,
    )
    proportional_energy2 = _compute_proportional_energy(
        grayscale_cams2,
        binary_mask2,
    )
    proportional_energy3 = _compute_proportional_energy(
        grayscale_cams3,
        binary_mask3,
    )

    # Ratios of floating-point sums: compare with a tolerance.  The second
    # expectation is written as its derivation (it equals 0.03814697265625)
    # instead of an unexplained magic number.
    assert np.isclose(proportional_energy, 0.25)
    assert np.isclose(proportional_energy2, (100 * 100) / (512 * 512))
    assert np.isclose(proportional_energy3, 0.25)
def test_compute_proportional_energy_no_activations():
    """With an all-zero saliency map the proportional energy is zero."""
    blank = np.zeros((200, 200))
    target = BoundingBoxes([BoundingBox(1, 50, 50, 50, 50)])

    mask = _compute_binary_mask(target, blank)
    energy = _compute_proportional_energy(blank, mask)

    # No energy anywhere: the test expects 0 rather than a 0/0 NaN.
    assert energy == 0
def test_compute_proportional_energy_no_gt_box():
    """A zero-area ground-truth box captures none of the saliency energy."""
    full_map = np.ones((200, 200))
    degenerate = BoundingBoxes([BoundingBox(1, 0, 0, 0, 0)])

    mask = _compute_binary_mask(degenerate, full_map)

    assert _compute_proportional_energy(full_map, mask) == 0
def test_process_sample():
    """``_process_sample`` reports zero for both metrics on a zero-area box."""
    cams = np.ones((200, 200))
    boxes = BoundingBoxes([BoundingBox(1, 50, 50, 0, 0)])

    # Unpacked as (proportional energy, average saliency focus).
    energy, focus = _process_sample(boxes, cams)

    assert energy == 0
    assert focus == 0