# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for the saliency interpretability metrics.

Covers the private helpers of ``mednet.engine.saliency.interpretability``.

Non-zero floating-point expectations are checked with :py:func:`math.isclose`
instead of ``==``, so the tests do not depend on the exact order of the
floating-point operations inside the implementation. Expected zeros are still
compared exactly.
"""

import math

import numpy as np

from mednet.config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes
from mednet.engine.saliency.interpretability import (
    _compute_avg_saliency_focus,
    _compute_binary_mask,
    _compute_max_iou_and_ioda,
    _compute_proportional_energy,
    _compute_simultaneous_iou_and_ioda,
    _process_sample,
)


def test_compute_max_iou_and_ioda():
    """IoU/IoDA are taken from the best-overlapping ground-truth box."""
    detected_box = BoundingBox(-1, 10, 10, 100, 100)
    gt_box = BoundingBox(1, 50, 50, 50, 50)
    gt_box2 = BoundingBox(1, 20, 20, 60, 60)
    gt_boxes = BoundingBoxes([gt_box, gt_box2])

    iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_boxes)

    # NOTE(review): the expectations in this module are consistent with
    # BoundingBox(label, xmin, ymin, width, height) -- confirm against the
    # datamodule. Under that reading gt_box2 overlaps the 100x100 detection
    # on a 60x60 region: 3600 / 10000 = 0.36 (gt_box alone would give 0.25).
    expected_iou = 0.36
    expected_ioda = 0.36
    assert math.isclose(iou, expected_iou)
    assert math.isclose(ioda, expected_ioda)


def test_compute_max_iou_and_ioda_zero_detected_area():
    """A detection with zero area yields zero IoU and IoDA."""
    detected_box = BoundingBox(-1, 10, 10, 0, 0)
    gt_box = BoundingBox(1, 50, 50, 50, 50)
    gt_boxes = BoundingBoxes([gt_box])

    iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_boxes)

    # Should be zero as the detected box has no area
    assert iou == 0
    assert ioda == 0


def test_compute_max_iou_and_ioda_zero_gt_area():
    """A ground-truth box with zero area yields zero IoU and IoDA."""
    detected_box = BoundingBox(-1, 10, 10, 100, 100)
    gt_box = BoundingBox(1, 50, 50, 0, 0)
    gt_boxes = BoundingBoxes([gt_box])

    iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_boxes)

    # Should be zero as the (only) ground-truth box has no area
    assert iou == 0
    assert ioda == 0


def test_compute_max_iou_and_ioda_zero_intersection():
    """Disjoint detection and ground-truth boxes yield zero IoU and IoDA."""
    detected_box = BoundingBox(-1, 10, 10, 100, 100)
    gt_box = BoundingBox(1, 0, 0, 5, 5)
    gt_boxes = BoundingBoxes([gt_box])

    iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_boxes)

    assert iou == 0
    assert ioda == 0


def test_compute_simultaneous_iou_and_ioda():
    """IoU/IoDA computed against all ground-truth boxes at once."""
    detected_box = BoundingBox(-1, 10, 10, 100, 100)
    gt_box = BoundingBox(1, 50, 50, 50, 50)
    gt_box2 = BoundingBox(1, 70, 70, 30, 30)
    gt_boxes = BoundingBoxes([gt_box, gt_box2])

    iou, ioda = _compute_simultaneous_iou_and_ioda(detected_box, gt_boxes)

    # NOTE(review): 0.34 == (50*50 + 30*30) / (100*100), i.e. the overlap of
    # gt_box2 is counted on top of gt_box's even though gt_box2 lies inside
    # gt_box. Presumably the implementation sums per-box intersections --
    # confirm this is the intended definition.
    assert math.isclose(iou, 0.34)
    assert math.isclose(ioda, 0.34)


def test_compute_avg_saliency_focus():
    """Average saliency inside the ground-truth mask for several maps."""
    grayscale_cams = np.ones((200, 200))
    grayscale_cams2 = np.full((512, 512), 0.5)
    # Only rows 50:75 of the 50x50 box region are active: half the box.
    grayscale_cams3 = np.zeros((256, 256))
    grayscale_cams3[50:75, 50:100] = 1

    gt_box = BoundingBox(1, 50, 50, 50, 50)
    gt_boxes = BoundingBoxes([gt_box])

    binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams)
    binary_mask2 = _compute_binary_mask(gt_boxes, grayscale_cams2)
    binary_mask3 = _compute_binary_mask(gt_boxes, grayscale_cams3)

    avg_saliency_focus = _compute_avg_saliency_focus(
        grayscale_cams, binary_mask
    )
    avg_saliency_focus2 = _compute_avg_saliency_focus(
        grayscale_cams2, binary_mask2
    )
    avg_saliency_focus3 = _compute_avg_saliency_focus(
        grayscale_cams3, binary_mask3
    )

    assert math.isclose(avg_saliency_focus, 1)
    assert math.isclose(avg_saliency_focus2, 0.5)
    assert math.isclose(avg_saliency_focus3, 0.5)


def test_compute_avg_saliency_focus_no_activations():
    """An all-zero saliency map has zero average focus."""
    grayscale_cams = np.zeros((200, 200))
    gt_box = BoundingBox(1, 50, 50, 50, 50)
    gt_boxes = BoundingBoxes([gt_box])

    binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams)

    avg_saliency_focus = _compute_avg_saliency_focus(
        grayscale_cams, binary_mask
    )

    assert avg_saliency_focus == 0


def test_compute_avg_saliency_focus_zero_gt_area():
    """A zero-area ground-truth box yields zero average focus."""
    grayscale_cams = np.ones((200, 200))
    gt_box = BoundingBox(1, 50, 50, 0, 0)
    gt_boxes = BoundingBoxes([gt_box])

    binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams)

    avg_saliency_focus = _compute_avg_saliency_focus(
        grayscale_cams, binary_mask
    )

    assert avg_saliency_focus == 0


def test_compute_proportional_energy():
    """Fraction of total saliency energy falling inside the mask."""
    grayscale_cams = np.ones((200, 200))
    grayscale_cams2 = np.full((512, 512), 0.5)
    grayscale_cams3 = np.zeros((512, 512))
    grayscale_cams3[100:200, 100:200] = 1

    gt_box = BoundingBox(1, 50, 50, 100, 100)
    gt_boxes = BoundingBoxes([gt_box])

    binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams)
    binary_mask2 = _compute_binary_mask(gt_boxes, grayscale_cams2)
    binary_mask3 = _compute_binary_mask(gt_boxes, grayscale_cams3)

    proportional_energy = _compute_proportional_energy(
        grayscale_cams, binary_mask
    )
    proportional_energy2 = _compute_proportional_energy(
        grayscale_cams2, binary_mask2
    )
    proportional_energy3 = _compute_proportional_energy(
        grayscale_cams3, binary_mask3
    )

    # 100*100 / (200*200)
    assert math.isclose(proportional_energy, 0.25)
    # (0.5 * 100*100) / (0.5 * 512*512)
    assert math.isclose(proportional_energy2, 0.03814697265625)
    # Only the 50x50 corner of the active square falls inside the box.
    assert math.isclose(proportional_energy3, 0.25)


def test_compute_proportional_energy_no_activations():
    """An all-zero saliency map has zero proportional energy."""
    grayscale_cams = np.zeros((200, 200))
    gt_box = BoundingBox(1, 50, 50, 50, 50)
    gt_boxes = BoundingBoxes([gt_box])

    binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams)

    proportional_energy = _compute_proportional_energy(
        grayscale_cams, binary_mask
    )

    assert proportional_energy == 0


def test_compute_proportional_energy_no_gt_box():
    """A zero-area ground-truth box yields zero proportional energy."""
    grayscale_cams = np.ones((200, 200))
    gt_box = BoundingBox(1, 0, 0, 0, 0)
    gt_boxes = BoundingBoxes([gt_box])

    binary_mask = _compute_binary_mask(gt_boxes, grayscale_cams)

    proportional_energy = _compute_proportional_energy(
        grayscale_cams, binary_mask
    )

    assert proportional_energy == 0


def test_process_sample():
    """Both metrics are zero for a zero-area ground-truth box."""
    grayscale_cams = np.ones((200, 200))
    gt_box = BoundingBox(1, 50, 50, 0, 0)
    gt_boxes = BoundingBoxes([gt_box])

    proportional_energy, avg_saliency_focus = _process_sample(
        gt_boxes, grayscale_cams
    )

    assert proportional_energy == 0
    assert avg_saliency_focus == 0