Skip to content
Snippets Groups Projects
Commit be4a633c authored by André Anjos's avatar André Anjos :speech_balloon: Committed by Daniel CARRON
Browse files

[engine.saliency.interpretability] Allow user to explicitly define target to be analysed

parent a7d4c64f
No related branches found
No related tags found
1 merge request!12Adds grad-cam support on classifiers
......@@ -23,14 +23,16 @@ logger = logging.getLogger(__name__)
def _ordered_connected_components(
saliency_map: typing.Sequence[typing.Sequence[float]]
| numpy.typing.NDArray[numpy.double],
threshold: float,
) -> list[numpy.typing.NDArray[numpy.bool_]]:
"""Calculates the largest connected components available on a saliency map
and return those as individual masks.
This implementation is based on [SCORECAM-2020]_:
1. Thresholding: The pixel values above 20% of max value are kept in the
original saliency map. Everything else is set to zero.
1. Thresholding: The pixel values above ``threshold``% of max value are
kept in the original saliency map. Everything else is set to zero. The
value proposed on [SCORECAM-2020]_ is 0.2. Use this value if unsure.
2. The thresholded saliency map is transformed into a boolean array (ones
are attributed to all elements above the threshold).
3. We call :py:func:`skimage.metrics.label` to evaluate all connected
......@@ -44,6 +46,10 @@ def _ordered_connected_components(
saliency_map
Input saliency map whose connected components will be calculated
from.
threshold
Relative threshold to be used to zero parts of the original saliency
map. A value of 0.2 will zero all values in the saliency map that are
below 20% of the maximum value observed in the said map.
Returns
......@@ -54,9 +60,10 @@ def _ordered_connected_components(
"""
# thresholds like [SCORECAM-2020]_
thresholded_mask = (saliency_map >= (0.2 * numpy.max(saliency_map))).astype(
numpy.uint8
)
saliency_array = numpy.array(saliency_map)
thresholded_mask = (
saliency_array >= (threshold * saliency_array.max())
).astype(numpy.uint8)
# avoids an all zeroes mask being processed
if not numpy.any(thresholded_mask):
......@@ -131,6 +138,45 @@ def _compute_max_iou_and_ioda(
return iou, ioda
def get_largest_bounding_boxes(
    saliency_map: typing.Sequence[typing.Sequence[float]]
    | numpy.typing.NDArray[numpy.double],
    n: int,
    threshold: float = 0.2,
) -> list[BoundingBox]:
    """Returns the N largest connected components as bounding boxes in a
    saliency map.

    The return of values is subject to the value of ``threshold`` applied, as
    well as on the saliency map itself. The number of objects found is also
    affected by those parameters.


    Parameters
    ----------

    saliency_map
        Input saliency map whose connected components will be calculated
        from.

    n
        The number of connected components to search for in the saliency map.
        Connected components are then translated to bounding-box notation.

    threshold
        Relative threshold to be used to zero parts of the original saliency
        map. A value of 0.2 will zero all values in the saliency map that are
        below 20% of the maximum value observed in the said map.


    Returns
    -------

    bounding_boxes
        Bounding boxes for (at most) the ``n`` largest connected components
        found in the saliency map.  May contain fewer than ``n`` entries
        (or none), depending on the saliency map and the threshold applied.
    """
    # _ordered_connected_components() returns masks ordered by size; slicing
    # an empty list is itself empty, so no explicit emptiness guard is needed.
    masks = _ordered_connected_components(saliency_map, threshold)
    return [_extract_bounding_box(mask) for mask in masks[:n]]
def _compute_simultaneous_iou_and_ioda(
detected_box: BoundingBox,
gt_bboxes: BoundingBoxes,
......@@ -217,7 +263,7 @@ def _compute_proportional_energy(
if denominator == 0.0:
return 0.0
return float(numpy.sum(saliency_map * gt_mask) / denominator)
return float(numpy.sum(saliency_map * gt_mask) / denominator) # type: ignore
def _process_sample(
......@@ -243,11 +289,10 @@ def _process_sample(
* Largest detected bounding box
"""
masks = _ordered_connected_components(saliency_map)
detected_box = BoundingBox(-1, 0, 0, 0, 0)
if masks:
# we get only the largest bounding box as of now
detected_box = _extract_bounding_box(masks[0])
largest_bbox = get_largest_bounding_boxes(saliency_map, n=1, threshold=0.2)
detected_box = (
largest_bbox[0] if largest_bbox else BoundingBox(-1, 0, 0, 0, 0)
)
# Calculate localization metrics
iou, ioda = _compute_max_iou_and_ioda(detected_box, gt_bboxes)
......@@ -271,6 +316,7 @@ def _process_sample(
def run(
input_folder: pathlib.Path,
target_label: int,
datamodule: lightning.pytorch.LightningDataModule,
) -> dict[str, list[typing.Any]]:
"""Applies visualization techniques on input CXR, outputs images with
......@@ -281,6 +327,9 @@ def run(
input_folder
Directory in which the saliency maps are stored for a specific
visualization type.
target_label
The label to target for evaluating interpretability metrics. Samples
containing any other label are ignored.
datamodule
The lightning datamodule to iterate on.
......@@ -318,6 +367,12 @@ def run(
name = str(sample[1]["name"][0])
label = int(sample[1]["label"].item())
if label != target_label:
# we add the entry for dataset completeness, but do not treat
# it
retval[dataset_name].append([name, label])
continue
# TODO: This is very specific to the TBX11k system for labelling
# regions of interest. We need to abstract from this to support more
# datasets and other ways to annotate.
......@@ -325,11 +380,6 @@ def run(
"bounding_boxes", BoundingBoxes()
)
if label == 0:
# we add the entry for dataset completeness
retval[dataset_name].append([name, label])
continue
if not bboxes:
logger.warning(
f"Sample `{name}` does not contdain bounding-box information. "
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment