From f0e1e38ede2ca03a9b37af9c17ed8349c9f92606 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Mon, 30 Oct 2023 18:40:16 -0300
Subject: [PATCH] [config.data.tbx11k.datamodule] Re-define bounding-box containers for better pytorch integration

---
 src/ptbench/config/data/tbx11k/datamodule.py | 120 +++++++++++++++---
 .../engine/saliency/interpretability.py      | 102 +++++----------
 src/ptbench/engine/visualizer.py             |   4 +-
 tests/test_tbx11k.py                         |   4 +-
 4 files changed, 139 insertions(+), 91 deletions(-)

diff --git a/src/ptbench/config/data/tbx11k/datamodule.py b/src/ptbench/config/data/tbx11k/datamodule.py
index 37ac5546..7d3d786e 100644
--- a/src/ptbench/config/data/tbx11k/datamodule.py
+++ b/src/ptbench/config/data/tbx11k/datamodule.py
@@ -2,12 +2,15 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import collections.abc
+import dataclasses
 import importlib.resources
 import os
 import typing
 
 import PIL.Image
 
+from torch.utils.data._utils.collate import default_collate_fn_map
 from torchvision.transforms.functional import to_tensor
 
 from ptbench.data.datamodule import CachingDataModule
@@ -22,22 +25,102 @@ CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
 database."""
 
 
-BoundingBoxAnnotation: typing.TypeAlias = tuple[int, int, int, int, int]
-"""Location of TB radiological findings (latent or active)
+@dataclasses.dataclass
+class BoundingBox:
+    """Location of radiological findings.
 
-Objects of this type carry bounding-box information of radiological findings on
-the original 512x512 pixel images of TBX11k. The radiological findings are
-defined as such:
+    Objects of this type carry the bounding-box location of radiological
+    findings on the original images of TBX11k. The radiological findings are
+    defined as follows:
+
+    * label: marks the sign as latent TB (0) or active TB (1)
+    * xmin: horizontal position of bounding box upper-left corner, in pixels
+    * ymin: vertical position of bounding box upper-left corner, in pixels
+    * width: width of the bounding box, in pixels
+    * height: height of the bounding box, in pixels
+    """
+
+    label: int
+    xmin: int
+    ymin: int
+    width: int
+    height: int
+
+    def area(self) -> int:
+        """Computes the bounding box area.
+
+        Returns
+        -------
+            The area in square pixels.
+        """
+        return self.width * self.height
+
+    @property
+    def xmax(self) -> int:
+        return self.xmin + self.width - 1
+
+    @property
+    def ymax(self) -> int:
+        return self.ymin + self.height - 1
+
+    def intersection(self, other: typing.Self) -> int:
+        """Computes the area intersection between bounding boxes.
+
+        Notice that screen (pixel) geometry dictates a slightly different
+        computation than floating-point geometry. Consider a 1D example for
+        the evaluation of the intersection:
+
+        * 2 points: x1 = 1 and x2 = 3, the distance is indeed x2 - x1 = 2
+        * 2 pixels of index: i1 = 1 and i2 = 3, the segment from pixel i1 to
+          i2 contains 3 pixels, i.e. l = i2 - i1 + 1
+
+        Parameters
+        ----------
+        other
+            The other bounding box to check intersections for.
+
+        Returns
+        -------
+            The area intersection between this and the other bounding box, in
+            square pixels.
+        """
+        dx = min(self.xmax, other.xmax) - max(self.xmin, other.xmin) + 1
+        dy = min(self.ymax, other.ymax) - max(self.ymin, other.ymin) + 1
+
+        if dx >= 0 and dy >= 0:
+            return dx * dy
+
+        return 0
+
+
+class BoundingBoxes(collections.abc.Sequence[BoundingBox]):
+    """A collection of bounding boxes."""
+
+    def __init__(self, t: typing.Sequence[BoundingBox] = []):
+        self.t = tuple(t)
+
+    def __getitem__(self, index):
+        return self.t[index]
+
+    def __len__(self) -> int:
+        return len(self.t)
+
+
+# We update the default collate function map to use our custom function as
+# explained at:
+# https://pytorch.org/docs/stable/data.html#torch.utils.data.default_collate
+def _collate_boundingboxes_fn(batch, *, collate_fn_map=None):
+    """Custom collate_fn() for pytorch dataloaders that keeps BoundingBoxes
+    objects as a plain list, instead of trying to stack them into a tensor."""
+    return batch
+
+
+default_collate_fn_map.update({BoundingBoxes: _collate_boundingboxes_fn})
 
-* 0/1: This sign is for latent TB (0), or active TB (1)
-* xmin: horizontal position of bounding box upper-left corner, in pixels
-* ymin: vertical position of bounding box upper-left corner, in pixels
-* width: width of the bounding box, in pixels
-* height: height of the bounding box, in pixels
-"""
 
 DatabaseSample: typing.TypeAlias = (
-    tuple[str, int] | tuple[str, int, list[BoundingBoxAnnotation]]
+    tuple[str, int] | tuple[str, int, tuple[tuple[int, int, int, int, int], ...]]
 )
 """Type of objects in our JSON representation for this database.
 
@@ -90,7 +173,7 @@ class RawDataLoader(_BaseRawDataLoader):
         return tensor, dict(
             label=sample[1],
             name=sample[0],
-            radsign_bboxes=self.bbox_annotations(sample),
+            bounding_boxes=self.bounding_boxes(sample),
         )
 
     def label(self, sample: DatabaseSample) -> int:
@@ -111,10 +194,8 @@ class RawDataLoader(_BaseRawDataLoader):
        """
        return sample[1]
 
-    def bbox_annotations(
-        self, sample: DatabaseSample
-    ) -> list[BoundingBoxAnnotation]:
-        """Loads a single image sample label from the disk.
+    def bounding_boxes(self, sample: DatabaseSample) -> BoundingBoxes:
+        """Loads the bounding-box annotations of an image sample from disk.
 
         Parameters
         ----------
@@ -129,7 +210,10 @@ class RawDataLoader(_BaseRawDataLoader):
        -------
            Bounding box annotations, if any available with the sample.
        """
-        return sample[2] if len(sample) > 2 else []  # type: ignore
+        if len(sample) > 2:
+            return BoundingBoxes([BoundingBox(*k) for k in sample[2]])  # type: ignore
+
+        return BoundingBoxes()
 
 
 def make_split(basename: str) -> DatabaseSplit:
diff --git a/src/ptbench/engine/saliency/interpretability.py b/src/ptbench/engine/saliency/interpretability.py
index ae9cca77..9ecf9096 100644
--- a/src/ptbench/engine/saliency/interpretability.py
+++ b/src/ptbench/engine/saliency/interpretability.py
@@ -15,6 +15,8 @@ import torchvision.ops
 
 from tqdm import tqdm
 
+from ...config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes
+
 logger = logging.getLogger(__name__)
 
 
@@ -74,7 +76,7 @@ def _ordered_connected_components(
 
 def _extract_bounding_box(
     mask: numpy.typing.NDArray[numpy.bool_],
-) -> tuple[int, int, int, int]:
+) -> BoundingBox:
     """Defines a bounding box surrounding a connected component mask.
Parameters @@ -85,23 +87,17 @@ def _extract_bounding_box( Returns ------- - A tuple of 4 integers representing the bounding box with the following - components: - - * top-left horizontal coordinate (``x``) (pixels) - * top-left vertical coordinate (``y``) (pixels) - * width (pixels) - * height (pixels) + A bounding box """ x, y, x2, y2 = torchvision.ops.masks_to_boxes(torch.tensor(mask)[None, :])[ 0 ] - return (int(x), int(y), int(x2 - x + 1), int(y2 - y + 1)) + return BoundingBox(-1, int(x), int(y), int(x2 - x + 1), int(y2 - y + 1)) def _compute_max_iou_and_ioda( - detected_box: tuple[int, int, int, int], - gt_bboxes: typing.Sequence[tuple[int, int, int, int]], + detected_box: BoundingBox, + gt_bboxes: BoundingBoxes, ) -> tuple[float, float]: """Will calculate how much of detected area lies in ground truth boxes. @@ -109,47 +105,35 @@ def _compute_max_iou_and_ioda( for each gt box separately and the gt box with the highest intersecting part will be used for the calculation. """ - x_L, y_L, w_L, h_L = detected_box - detected_area = w_L * h_L + detected_area = detected_box.area() if detected_area == 0: - return 0, 0 + return 0.0, 0.0 + max_intersection = 0 max_gt_area = 0 for bbox in gt_bboxes: - xmin, ymin, width, height = bbox - - gt_area = width * height - - # Calculate intersection - xi = max(x_L, xmin) - yi = max(y_L, ymin) - wi = max(0, min(x_L + w_L, xmin + width) - xi) - hi = max(0, min(y_L + h_L, ymin + height) - yi) - - intersection = 0 - if wi > 0 and hi > 0: - intersection = wi * hi - + intersection = bbox.intersection(detected_box) if intersection > max_intersection: max_intersection = intersection - max_gt_area = gt_area + max_gt_area = bbox.area() if max_gt_area == 0 and max_intersection == 0: # This case means no intersection was found, even though there are gt boxes iou, ioda = 0.0, 0.0 + else: iou = max_intersection / ( detected_area + max_gt_area - max_intersection ) ioda = max_intersection / detected_area - return float(iou), float(ioda) + return iou, ioda def _compute_simultaneous_iou_and_ioda( - detected_box: tuple[int, int, int, int], - gt_bboxes: typing.Sequence[tuple[int, int, int, int]], + detected_box: BoundingBox, + gt_bboxes: BoundingBoxes, ) -> tuple[float, float]: """Will calculate how much of detected area lies between ground truth boxes. @@ -158,27 +142,13 @@ def _compute_simultaneous_iou_and_ioda( will be compared to them simultaneously (and not to each gt box separately). 
""" - x_L, y_L, w_L, h_L = detected_box - detected_area = w_L * h_L + + detected_area = detected_box.area() if detected_area == 0: return 0, 0 - intersection = 0 - total_gt_area = 0 - - for bbox in gt_bboxes: - xmin, ymin, width, height = bbox - - gt_area = width * height - total_gt_area += gt_area - - # Calculate intersection - xi = max(x_L, xmin) - yi = max(y_L, ymin) - wi = max(0, min(x_L + w_L, xmin + width) - xi) - hi = max(0, min(y_L + h_L, ymin + height) - yi) - if wi > 0 and hi > 0: - intersection += wi * hi + intersection = sum([k.intersection(detected_box) for k in gt_bboxes]) + total_gt_area = sum([k.area() for k in gt_bboxes]) iou = intersection / (detected_area + total_gt_area - intersection) ioda = intersection / detected_area @@ -251,39 +221,30 @@ def _compute_proportional_energy( def _process_sample( - gt_bboxes: typing.Sequence[tuple[int, int, int, int]], + gt_bboxes: BoundingBoxes, saliency_map: numpy.typing.NDArray[numpy.double], -) -> tuple[float, float, float, float, int, int, int, int]: +) -> tuple[float, float, float, float, BoundingBox]: """Calculates the metrics for a single sample. Parameters ---------- - gt_bboxes - A list of ground-truth bounding boxes following the format: - - * xmin: horizontal position of bounding box upper-left corner, in - pixels - * ymin: vertical position of bounding box upper-left corner, in - pixels - * width: width of the bounding box, in pixels - * height: height of the bounding box, in pixels + A list of ground-truth bounding boxes. Returns ------- - A tuple containing the following values: * IoU * IoDA * Proportional energy * Average saliency focus - * Largest detected bounding box (x, y, width, height) + * Largest detected bounding box """ masks = _ordered_connected_components(saliency_map) - detected_box = (0, 0, 0, 0) + detected_box = BoundingBox(-1, 0, 0, 0, 0) if masks: # we get only the largest bounding box as of now detected_box = _extract_bounding_box(masks[0]) @@ -294,15 +255,17 @@ def _process_sample( # The binary_mask will be ON/True where the gt boxes are located binary_mask = numpy.zeros_like(saliency_map, dtype=numpy.bool_) for bbox in gt_bboxes: - xmin, ymin, width, height = bbox - binary_mask[ymin : ymin + height, xmin : xmin + width] = True + binary_mask[ + bbox.ymin : bbox.ymin + bbox.height, + bbox.xmin : bbox.xmin + bbox.width, + ] = True return ( iou, ioda, _compute_proportional_energy(saliency_map, binary_mask), _compute_avg_saliency_focus(saliency_map, binary_mask), - *detected_box, + detected_box, ) @@ -358,8 +321,9 @@ def run( # TODO: This is very specific to the TBX11k system for labelling # regions of interest. We need to abstract from this to support more # datasets and other ways to annotate. - bboxes = sample[1].get("radsign_bboxes", []) - bboxes = [k[1:] for k in bboxes] # remove bbox label... 
+ bboxes: BoundingBoxes = sample[1].get( + "bounding_boxes", BoundingBoxes() + ) if label == 0: # we add the entry for dataset completeness diff --git a/src/ptbench/engine/visualizer.py b/src/ptbench/engine/visualizer.py index 46f22e02..fed11421 100644 --- a/src/ptbench/engine/visualizer.py +++ b/src/ptbench/engine/visualizer.py @@ -103,11 +103,11 @@ def run( includes_bboxes = True if ( samples[1]["label"].item() == 0 - or "radsign_bboxes" not in samples[1] + or "bounding_boxes" not in samples[1] ): includes_bboxes = False else: - gt_bboxes = samples[1]["radsign_bboxes"] + gt_bboxes = samples[1]["bounding_boxes"] if not gt_bboxes: includes_bboxes = False diff --git a/tests/test_tbx11k.py b/tests/test_tbx11k.py index 1b3a7222..a4c31edb 100644 --- a/tests/test_tbx11k.py +++ b/tests/test_tbx11k.py @@ -186,10 +186,10 @@ def check_loaded_batch( [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]] ) - assert "radsign_bboxes" in batch[1] + assert "bounding_boxes" in batch[1] for sample, label, bboxes in zip( - batch[0], batch[1]["label"], batch[1]["radsign_bboxes"] + batch[0], batch[1]["label"], batch[1]["bounding_boxes"] ): # there must be a sign indicated on the image, if active TB is detected if label == 1: -- GitLab
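
The snippet below is a quick usage sketch and is not part of the patch: it assumes the patched module is importable as ptbench.config.data.tbx11k.datamodule, that default_collate comes from torch.utils.data, and uses made-up box coordinates. It illustrates the inclusive pixel geometry behind BoundingBox.intersection() and how the registered collate override keeps BoundingBoxes values as a plain list when batching.

from torch.utils.data import default_collate

from ptbench.config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes

# Two overlapping boxes on a pixel grid: with inclusive coordinates,
# box "a" spans columns 10..29 and rows 10..29 (xmax = xmin + width - 1).
a = BoundingBox(label=1, xmin=10, ymin=10, width=20, height=20)
b = BoundingBox(label=0, xmin=25, ymin=25, width=10, height=10)

assert a.area() == 400
assert (a.xmax, a.ymax) == (29, 29)
assert a.intersection(b) == 25  # 5x5 pixel overlap, counted inclusively
assert b.intersection(a) == 25  # symmetric

# BoundingBoxes is a Sequence, so it can be indexed and iterated as usual.
gt = BoundingBoxes([a, b])
assert len(gt) == 2 and gt[0].label == 1

# default_collate() normally stacks per-key values into tensors; thanks to
# the entry registered in default_collate_fn_map, BoundingBoxes values are
# passed through as a plain per-batch list instead.
batch = [
    {"label": 1, "bounding_boxes": gt},
    {"label": 0, "bounding_boxes": BoundingBoxes()},
]
collated = default_collate(batch)
assert isinstance(collated["bounding_boxes"], list)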