diff --git a/src/ptbench/engine/evaluator.py b/src/ptbench/engine/evaluator.py index 8c1a8d843b542b8c8ecefb4496614d8df1a5ca62..829b9a6b04db10beb414bba5ed04aadb24b472b9 100644 --- a/src/ptbench/engine/evaluator.py +++ b/src/ptbench/engine/evaluator.py @@ -318,6 +318,8 @@ def aggregate_roc( A dictionary mapping split names to ROC curve data produced by :py:func:sklearn.metrics.roc_curve`. + title + The title of the plot. Returns ------- @@ -471,6 +473,8 @@ def aggregate_pr( data A dictionary mapping split names to ROC curve data produced by :py:func:sklearn.metrics.precision_recall_curve`. + title + The title of the plot. Returns diff --git a/src/ptbench/engine/saliency/completeness.py b/src/ptbench/engine/saliency/completeness.py index 2d42f21a7fbe84cc60faddf6a4923db8f7218292..c711c7b3272ddeeda4e645c4d436a4a18376610d 100644 --- a/src/ptbench/engine/saliency/completeness.py +++ b/src/ptbench/engine/saliency/completeness.py @@ -123,7 +123,28 @@ def _process_sample( percentiles: typing.Sequence[int], ) -> list: """Helper function to :py:func:`run` to be used in multiprocessing - contexts.""" + contexts. + + Parameters + ---------- + model + Neural network model (e.g. pasa). + device + The device to process samples on. + saliency_map_callable + A callable saliency-map generator from grad-cam + target_class + Class to target for saliency estimation. Can be either set to + "all" or "highest". "highest". + positive only + If set, and the model chosen has a single output (binary), then + saliency maps will only be generated for samples of the positive class + + percentiles + A sequence of percentiles (percent x100) integer values indicating the + proportion of pixels to perturb in the original image to calculate both + MoRF and LeRF scores. + """ name: str = sample[1]["name"][0] label: int = int(sample[1]["label"].item()) diff --git a/src/ptbench/engine/saliency/generator.py b/src/ptbench/engine/saliency/generator.py index c8fa1d4614541e351af67ab341cf71bc63731b5b..cb101e98ffeca29ad9c7d8a1e5cae7492f01c381 100644 --- a/src/ptbench/engine/saliency/generator.py +++ b/src/ptbench/engine/saliency/generator.py @@ -90,7 +90,18 @@ def _create_saliency_map_callable( def _save_saliency_map( output_folder: pathlib.Path, name: str, saliency_map: torch.Tensor ) -> None: - """Helper function to save a saliency map to disk.""" + """Helper function to save a saliency map to disk. + + Parameters + --------- + output_folder + Directory in which the resulting saliency maps will be saved. + name + Name of the saved file. + saliency_map + A real-valued saliency-map that conveys regions used for + classification in the original sample. + """ n = pathlib.Path(name) (output_folder / n.parent).mkdir(parents=True, exist_ok=True) diff --git a/src/ptbench/engine/saliency/interpretability.py b/src/ptbench/engine/saliency/interpretability.py index dba6eee30efb2d1d431a4c7ac5274f659ab8a902..89873d0e7125ef15716d062ef13413da1dc4bb6c 100644 --- a/src/ptbench/engine/saliency/interpretability.py +++ b/src/ptbench/engine/saliency/interpretability.py @@ -88,7 +88,7 @@ def _extract_bounding_box( Returns ------- - A bounding box + A bounding box. """ x, y, x2, y2 = torchvision.ops.masks_to_boxes(torch.tensor(mask)[None, :])[ 0 @@ -105,6 +105,20 @@ def _compute_max_iou_and_ioda( If there are multiple gt boxes, the detected area will be calculated for each gt box separately and the gt box with the highest intersecting part will be used for the calculation. + + + Parameters + ---------- + detected_box + BoundingBox of the detected area. + gt_bboxes + Ground-truth bounding boxes in the format ``(x, y, width, + height)``. + + + Returns + ------- + The max iou and ioda values. """ detected_area = detected_box.area() if detected_area == 0: @@ -146,6 +160,8 @@ def _get_largest_bounding_boxes( affected by those parameters. + Parameters + ---------- saliency_map Input saliciency map whose connected components will be calculated from. @@ -160,6 +176,7 @@ def _get_largest_bounding_boxes( Returns ------- + The N largest connected components as bounding boxes in a saliency map. """ retval: list[BoundingBox] = [] @@ -181,6 +198,20 @@ def _compute_simultaneous_iou_and_ioda( This means that if there are multiple gt boxes, the detected area will be compared to them simultaneously (and not to each gt box separately). + + + Parameters + ---------- + detected_box + BoundingBox of the detected area. + gt_bboxes + Collection of bounding boxes of the ground-truth drawn as + ``True`` values. + + + Returns + ------- + The iou and ioda for the provided boxes. """ detected_area = detected_box.area() @@ -210,12 +241,12 @@ def _compute_avg_saliency_focus( Parameters ---------- - gt_bboxes - Ground-truth bounding boxes in the format ``(x, y, width, - height)``. - gt_mask - Ground-truth mask containing the bounding boxes of the ground-truth - drawn as ``True`` values. + saliency_map + A real-valued saliency-map that conveys regions used for + classification in the original sample. + gt_mask + Ground-truth mask containing the bounding boxes of the ground-truth + drawn as ``True`` values. Returns @@ -239,12 +270,12 @@ def _compute_proportional_energy( Parameters ---------- - saliency_map - A real-valued saliency-map that conveys regions used for - classification in the original sample. - gt_mask - Ground-truth mask containing the bounding boxes of the ground-truth - drawn as ``True`` values. + saliency_map + A real-valued saliency-map that conveys regions used for + classification in the original sample. + gt_mask + Ground-truth mask containing the bounding boxes of the ground-truth + drawn as ``True`` values. Returns @@ -268,15 +299,16 @@ def _compute_binary_mask( The binary_mask will be ON/True where the gt boxes are located. + Parameters ---------- - gt_bboxes - Ground-truth bounding boxes in the format ``(x, y, width, - height)``. + gt_bboxes + Ground-truth bounding boxes in the format ``(x, y, width, + height)``. - saliency_map - A real-valued saliency-map that conveys regions used for - classification in the original sample. + saliency_map + A real-valued saliency-map that conveys regions used for + classification in the original sample. Returns @@ -305,6 +337,9 @@ def _process_sample( ---------- gt_bboxes A list of ground-truth bounding boxes. + saliency_map + A real-valued saliency-map that conveys regions used for + classification in the original sample. Returns @@ -364,7 +399,6 @@ def run( Returns ------- - A dictionary where keys are dataset names in the provide datamodule, and values are lists containing sample information alongside metrics calculated: