From f0e1e38ede2ca03a9b37af9c17ed8349c9f92606 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Mon, 30 Oct 2023 18:40:16 -0300
Subject: [PATCH] [config.data.tbx11k.datamodule] Re-define bounding-box
 containers for better pytorch integration

---
 src/ptbench/config/data/tbx11k/datamodule.py  | 120 +++++++++++++++---
 .../engine/saliency/interpretability.py       | 102 +++++----------
 src/ptbench/engine/visualizer.py              |   4 +-
 tests/test_tbx11k.py                          |   4 +-
 4 files changed, 139 insertions(+), 91 deletions(-)

diff --git a/src/ptbench/config/data/tbx11k/datamodule.py b/src/ptbench/config/data/tbx11k/datamodule.py
index 37ac5546..7d3d786e 100644
--- a/src/ptbench/config/data/tbx11k/datamodule.py
+++ b/src/ptbench/config/data/tbx11k/datamodule.py
@@ -2,12 +2,15 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import collections.abc
+import dataclasses
 import importlib.resources
 import os
 import typing
 
 import PIL.Image
 
+from torch.utils.data._utils.collate import default_collate_fn_map
 from torchvision.transforms.functional import to_tensor
 
 from ptbench.data.datamodule import CachingDataModule
@@ -22,22 +25,102 @@ CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
 database."""
 
 
-BoundingBoxAnnotation: typing.TypeAlias = tuple[int, int, int, int, int]
-"""Location of TB radiological findings (latent or active)
+@dataclasses.dataclass
+class BoundingBox:
+    """Location of radiological findings.
 
-Objects of this type carry bounding-box information of radiological findings on
-the original 512x512 pixel images of TBX11k.  The radiological findings are
-defined as such:
+    Objects of this type carry bounding-box location of radiological findings
+    on the original images of TBX11k.  The radiological findings are defined as
+    such:
+
+    * 0/1: This labels the sign as latent TB (0), or active TB (1)
+    * xmin: horizontal position of bounding box upper-left corner, in pixels
+    * ymin: vertical position of bounding box upper-left corner, in pixels
+    * width: width of the bounding box, in pixels
+    * height: height of the bounding box, in pixels
+    """
+
+    label: int
+    xmin: int
+    ymin: int
+    width: int
+    height: int
+
+    def area(self) -> int:
+        """Computes the bounding box area.
+
+        Returns
+        -------
+            The area in square-pixels.
+        """
+        return self.width * self.height
+
+    @property
+    def xmax(self) -> int:
+        return self.xmin + self.width - 1
+
+    @property
+    def ymax(self) -> int:
+        return self.ymin + self.height - 1
+
+    def intersection(self, other: typing.Self) -> int:
+        """Computes the area intersection between bounding boxes.
+
+        Notice that screen geometry dictates is slightly different from
+        floating point metrics. Consider a 1D example for the evaluation of the
+        intersection:
+
+        * 2 points : x1 = 1 and x2 = 3, the distance is indeed x2-x1 = 2
+        * 2 pixels of index : i1 = 1 and i2 = 3, the segment from pixel i1 to
+          i2 contains 3 pixels ie l = i2 - i1 + 1
+
+
+        Parameters
+        ----------
+        other
+            The other bounding box to check intersections for
+
+        Returns
+        -------
+            The area intersection between this and the other bounding-box in
+            square pixels.
+        """
+        dx = min(self.xmax, other.xmax) - max(self.xmin, other.xmin) + 1
+        dy = min(self.ymax, other.ymax) - max(self.ymin, other.ymin) + 1
+
+        if dx >= 0 and dy >= 0:
+            return dx * dy
+
+        return 0
+
+
+class BoundingBoxes(collections.abc.Sequence[BoundingBox]):
+    """A collection of bounding boxes."""
+
+    def __init__(self, t: typing.Sequence[BoundingBox] = []):
+        self.t = tuple(t)
+
+    def __getitem__(self, index):
+        return self.t[index]
+
+    def __len__(self) -> int:
+        return len(self.t)
+
+
+# We update the default collate function map to use our custom function as
+# explained at:
+# https://pytorch.org/docs/stable/data.html#torch.utils.data.default_collate
+def _collate_boundingboxes_fn(batch, *, collate_fn_map=None):
+    """Custom collate_fn() for pytorch dataloaders that ignores BoundingBoxes
+    objects."""
+    return batch
+
+
+default_collate_fn_map.update({BoundingBoxes: _collate_boundingboxes_fn})
 
-* 0/1: This sign is for latent TB (0), or active TB (1)
-* xmin: horizontal position of bounding box upper-left corner, in pixels
-* ymin: vertical position of bounding box upper-left corner, in pixels
-* width: width of the bounding box, in pixels
-* height: height of the bounding box, in pixels
-"""
 
 DatabaseSample: typing.TypeAlias = (
-    tuple[str, int] | tuple[str, int, list[BoundingBoxAnnotation]]
+    tuple[str, int] | tuple[str, int, tuple[tuple[int, int, int, int, int]]]
 )
 """Type of objects in our JSON representation for this database.
 
@@ -90,7 +173,7 @@ class RawDataLoader(_BaseRawDataLoader):
         return tensor, dict(
             label=sample[1],
             name=sample[0],
-            radsign_bboxes=self.bbox_annotations(sample),
+            bounding_boxes=self.bounding_boxes(sample),
         )
 
     def label(self, sample: DatabaseSample) -> int:
@@ -111,10 +194,8 @@ class RawDataLoader(_BaseRawDataLoader):
         """
         return sample[1]
 
-    def bbox_annotations(
-        self, sample: DatabaseSample
-    ) -> list[BoundingBoxAnnotation]:
-        """Loads a single image sample label from the disk.
+    def bounding_boxes(self, sample: DatabaseSample) -> BoundingBoxes:
+        """Loads image annotated bounding-boxes from the disk.
 
         Parameters
         ----------
@@ -129,7 +210,10 @@ class RawDataLoader(_BaseRawDataLoader):
         -------
             Bounding box annotations, if any available with the sample.
         """
-        return sample[2] if len(sample) > 2 else []  # type: ignore
+        if len(sample) > 2:
+            return BoundingBoxes([BoundingBox(*k) for k in sample[2]])  # type: ignore
+
+        return BoundingBoxes()
 
 
 def make_split(basename: str) -> DatabaseSplit:
diff --git a/src/ptbench/engine/saliency/interpretability.py b/src/ptbench/engine/saliency/interpretability.py
index ae9cca77..9ecf9096 100644
--- a/src/ptbench/engine/saliency/interpretability.py
+++ b/src/ptbench/engine/saliency/interpretability.py
@@ -15,6 +15,8 @@ import torchvision.ops
 
 from tqdm import tqdm
 
+from ...config.data.tbx11k.datamodule import BoundingBox, BoundingBoxes
+
 logger = logging.getLogger(__name__)
 
 
@@ -74,7 +76,7 @@ def _ordered_connected_components(
 
 def _extract_bounding_box(
     mask: numpy.typing.NDArray[numpy.bool_],
-) -> tuple[int, int, int, int]:
+) -> BoundingBox:
     """Defines a bounding box surrounding a connected component mask.
 
     Parameters
@@ -85,23 +87,17 @@ def _extract_bounding_box(
 
     Returns
     -------
-        A tuple of 4 integers representing the bounding box with the following
-        components:
-
-        * top-left horizontal coordinate (``x``) (pixels)
-        * top-left vertical coordinate (``y``) (pixels)
-        * width (pixels)
-        * height (pixels)
+        A bounding box
     """
     x, y, x2, y2 = torchvision.ops.masks_to_boxes(torch.tensor(mask)[None, :])[
         0
     ]
-    return (int(x), int(y), int(x2 - x + 1), int(y2 - y + 1))
+    return BoundingBox(-1, int(x), int(y), int(x2 - x + 1), int(y2 - y + 1))
 
 
 def _compute_max_iou_and_ioda(
-    detected_box: tuple[int, int, int, int],
-    gt_bboxes: typing.Sequence[tuple[int, int, int, int]],
+    detected_box: BoundingBox,
+    gt_bboxes: BoundingBoxes,
 ) -> tuple[float, float]:
     """Will calculate how much of detected area lies in ground truth boxes.
 
@@ -109,47 +105,35 @@ def _compute_max_iou_and_ioda(
     for each gt box separately and the gt box with the highest
     intersecting part will be used for the calculation.
     """
-    x_L, y_L, w_L, h_L = detected_box
-    detected_area = w_L * h_L
+    detected_area = detected_box.area()
     if detected_area == 0:
-        return 0, 0
+        return 0.0, 0.0
+
     max_intersection = 0
     max_gt_area = 0
 
     for bbox in gt_bboxes:
-        xmin, ymin, width, height = bbox
-
-        gt_area = width * height
-
-        # Calculate intersection
-        xi = max(x_L, xmin)
-        yi = max(y_L, ymin)
-        wi = max(0, min(x_L + w_L, xmin + width) - xi)
-        hi = max(0, min(y_L + h_L, ymin + height) - yi)
-
-        intersection = 0
-        if wi > 0 and hi > 0:
-            intersection = wi * hi
-
+        intersection = bbox.intersection(detected_box)
         if intersection > max_intersection:
             max_intersection = intersection
-            max_gt_area = gt_area
+            max_gt_area = bbox.area()
 
     if max_gt_area == 0 and max_intersection == 0:
         # This case means no intersection was found, even though there are gt boxes
         iou, ioda = 0.0, 0.0
+
     else:
         iou = max_intersection / (
             detected_area + max_gt_area - max_intersection
         )
         ioda = max_intersection / detected_area
 
-    return float(iou), float(ioda)
+    return iou, ioda
 
 
 def _compute_simultaneous_iou_and_ioda(
-    detected_box: tuple[int, int, int, int],
-    gt_bboxes: typing.Sequence[tuple[int, int, int, int]],
+    detected_box: BoundingBox,
+    gt_bboxes: BoundingBoxes,
 ) -> tuple[float, float]:
     """Will calculate how much of detected area lies between ground truth
     boxes.
@@ -158,27 +142,13 @@ def _compute_simultaneous_iou_and_ioda(
     will be compared to them simultaneously (and not to each gt box
     separately).
     """
-    x_L, y_L, w_L, h_L = detected_box
-    detected_area = w_L * h_L
+
+    detected_area = detected_box.area()
     if detected_area == 0:
         return 0, 0
-    intersection = 0
-    total_gt_area = 0
-
-    for bbox in gt_bboxes:
-        xmin, ymin, width, height = bbox
-
-        gt_area = width * height
-        total_gt_area += gt_area
-
-        # Calculate intersection
-        xi = max(x_L, xmin)
-        yi = max(y_L, ymin)
-        wi = max(0, min(x_L + w_L, xmin + width) - xi)
-        hi = max(0, min(y_L + h_L, ymin + height) - yi)
 
-        if wi > 0 and hi > 0:
-            intersection += wi * hi
+    intersection = sum([k.intersection(detected_box) for k in gt_bboxes])
+    total_gt_area = sum([k.area() for k in gt_bboxes])
 
     iou = intersection / (detected_area + total_gt_area - intersection)
     ioda = intersection / detected_area
@@ -251,39 +221,30 @@ def _compute_proportional_energy(
 
 
 def _process_sample(
-    gt_bboxes: typing.Sequence[tuple[int, int, int, int]],
+    gt_bboxes: BoundingBoxes,
     saliency_map: numpy.typing.NDArray[numpy.double],
-) -> tuple[float, float, float, float, int, int, int, int]:
+) -> tuple[float, float, float, float, BoundingBox]:
     """Calculates the metrics for a single sample.
 
     Parameters
     ----------
-
     gt_bboxes
-        A list of ground-truth bounding boxes following the format:
-
-        * xmin: horizontal position of bounding box upper-left corner, in
-          pixels
-        * ymin: vertical position of bounding box upper-left corner, in
-          pixels
-        * width: width of the bounding box, in pixels
-        * height: height of the bounding box, in pixels
+        A list of ground-truth bounding boxes.
 
 
     Returns
     -------
-
         A tuple containing the following values:
 
         * IoU
         * IoDA
         * Proportional energy
         * Average saliency focus
-        * Largest detected bounding box (x, y, width, height)
+        * Largest detected bounding box
     """
 
     masks = _ordered_connected_components(saliency_map)
-    detected_box = (0, 0, 0, 0)
+    detected_box = BoundingBox(-1, 0, 0, 0, 0)
     if masks:
         # we get only the largest bounding box as of now
         detected_box = _extract_bounding_box(masks[0])
@@ -294,15 +255,17 @@ def _process_sample(
     # The binary_mask will be ON/True where the gt boxes are located
     binary_mask = numpy.zeros_like(saliency_map, dtype=numpy.bool_)
     for bbox in gt_bboxes:
-        xmin, ymin, width, height = bbox
-        binary_mask[ymin : ymin + height, xmin : xmin + width] = True
+        binary_mask[
+            bbox.ymin : bbox.ymin + bbox.height,
+            bbox.xmin : bbox.xmin + bbox.width,
+        ] = True
 
     return (
         iou,
         ioda,
         _compute_proportional_energy(saliency_map, binary_mask),
         _compute_avg_saliency_focus(saliency_map, binary_mask),
-        *detected_box,
+        detected_box,
     )
 
 
@@ -358,8 +321,9 @@ def run(
             # TODO: This is very specific to the TBX11k system for labelling
             # regions of interest.  We need to abstract from this to support more
             # datasets and other ways to annotate.
-            bboxes = sample[1].get("radsign_bboxes", [])
-            bboxes = [k[1:] for k in bboxes]  # remove bbox label...
+            bboxes: BoundingBoxes = sample[1].get(
+                "bounding_boxes", BoundingBoxes()
+            )
 
             if label == 0:
                 # we add the entry for dataset completeness
diff --git a/src/ptbench/engine/visualizer.py b/src/ptbench/engine/visualizer.py
index 46f22e02..fed11421 100644
--- a/src/ptbench/engine/visualizer.py
+++ b/src/ptbench/engine/visualizer.py
@@ -103,11 +103,11 @@ def run(
         includes_bboxes = True
         if (
             samples[1]["label"].item() == 0
-            or "radsign_bboxes" not in samples[1]
+            or "bounding_boxes" not in samples[1]
         ):
             includes_bboxes = False
         else:
-            gt_bboxes = samples[1]["radsign_bboxes"]
+            gt_bboxes = samples[1]["bounding_boxes"]
             if not gt_bboxes:
                 includes_bboxes = False
 
diff --git a/tests/test_tbx11k.py b/tests/test_tbx11k.py
index 1b3a7222..a4c31edb 100644
--- a/tests/test_tbx11k.py
+++ b/tests/test_tbx11k.py
@@ -186,10 +186,10 @@ def check_loaded_batch(
         [any([k.startswith(j) for j in prefixes]) for k in batch[1]["name"]]
     )
 
-    assert "radsign_bboxes" in batch[1]
+    assert "bounding_boxes" in batch[1]
 
     for sample, label, bboxes in zip(
-        batch[0], batch[1]["label"], batch[1]["radsign_bboxes"]
+        batch[0], batch[1]["label"], batch[1]["bounding_boxes"]
     ):
         # there must be a sign indicated on the image, if active TB is detected
         if label == 1:
-- 
GitLab