From 22468a8ab9f560ef3fa4eb335d161753d29667a7 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Mon, 22 Jan 2024 11:36:21 +0100
Subject: [PATCH] [doc] Add missing docstring parameters and returns

---
 src/ptbench/engine/evaluator.py               |  4 +
 src/ptbench/engine/saliency/completeness.py   | 23 +++++-
 src/ptbench/engine/saliency/generator.py      | 13 +++-
 .../engine/saliency/interpretability.py       | 74 ++++++++++++++-----
 4 files changed, 92 insertions(+), 22 deletions(-)

diff --git a/src/ptbench/engine/evaluator.py b/src/ptbench/engine/evaluator.py
index 8c1a8d84..829b9a6b 100644
--- a/src/ptbench/engine/evaluator.py
+++ b/src/ptbench/engine/evaluator.py
@@ -318,6 +318,8 @@ def aggregate_roc(
         A dictionary mapping split names to ROC curve data produced by
         :py:func:sklearn.metrics.roc_curve`.
 
+    title
+        The title of the plot.
 
     Returns
     -------
@@ -471,6 +473,8 @@ def aggregate_pr(
     data
         A dictionary mapping split names to ROC curve data produced by
         :py:func:sklearn.metrics.precision_recall_curve`.
+    title
+        The title of the plot.
 
 
     Returns
diff --git a/src/ptbench/engine/saliency/completeness.py b/src/ptbench/engine/saliency/completeness.py
index 2d42f21a..c711c7b3 100644
--- a/src/ptbench/engine/saliency/completeness.py
+++ b/src/ptbench/engine/saliency/completeness.py
@@ -123,7 +123,28 @@ def _process_sample(
     percentiles: typing.Sequence[int],
 ) -> list:
     """Helper function to :py:func:`run` to be used in multiprocessing
-    contexts."""
+    contexts.
+
+    Parameters
+    ----------
+    model
+        Neural network model (e.g. pasa).
+    device
+        The device to process samples on.
+    saliency_map_callable
+        A callable saliency-map generator from grad-cam
+    target_class
+        Class to target for saliency estimation. Can be either set to
+        "all" or "highest". "highest".
+    positive only
+        If set, and the model chosen has a single output (binary), then
+        saliency maps will only be generated for samples of the positive class
+
+    percentiles
+        A sequence of percentiles (percent x100) integer values indicating the
+        proportion of pixels to perturb in the original image to calculate both
+        MoRF and LeRF scores.
+    """
 
     name: str = sample[1]["name"][0]
     label: int = int(sample[1]["label"].item())
diff --git a/src/ptbench/engine/saliency/generator.py b/src/ptbench/engine/saliency/generator.py
index c8fa1d46..cb101e98 100644
--- a/src/ptbench/engine/saliency/generator.py
+++ b/src/ptbench/engine/saliency/generator.py
@@ -90,7 +90,18 @@ def _create_saliency_map_callable(
 def _save_saliency_map(
     output_folder: pathlib.Path, name: str, saliency_map: torch.Tensor
 ) -> None:
-    """Helper function to save a saliency map to disk."""
+    """Helper function to save a saliency map to disk.
+
+    Parameters
+    ---------
+    output_folder
+        Directory in which the resulting saliency maps will be saved.
+    name
+        Name of the saved file.
+    saliency_map
+        A real-valued saliency-map that conveys regions used for
+        classification in the original sample.
+    """
 
     n = pathlib.Path(name)
     (output_folder / n.parent).mkdir(parents=True, exist_ok=True)
diff --git a/src/ptbench/engine/saliency/interpretability.py b/src/ptbench/engine/saliency/interpretability.py
index dba6eee3..89873d0e 100644
--- a/src/ptbench/engine/saliency/interpretability.py
+++ b/src/ptbench/engine/saliency/interpretability.py
@@ -88,7 +88,7 @@ def _extract_bounding_box(
 
     Returns
     -------
-        A bounding box
+        A bounding box.
     """
     x, y, x2, y2 = torchvision.ops.masks_to_boxes(torch.tensor(mask)[None, :])[
         0
@@ -105,6 +105,20 @@ def _compute_max_iou_and_ioda(
     If there are multiple gt boxes, the detected area will be calculated
     for each gt box separately and the gt box with the highest
     intersecting part will be used for the calculation.
+
+
+    Parameters
+    ----------
+    detected_box
+        BoundingBox of the detected area.
+    gt_bboxes
+        Ground-truth bounding boxes in the format ``(x, y, width,
+        height)``.
+
+
+    Returns
+    -------
+        The max iou and ioda values.
     """
     detected_area = detected_box.area()
     if detected_area == 0:
@@ -146,6 +160,8 @@ def _get_largest_bounding_boxes(
     affected by those parameters.
 
 
+    Parameters
+    ----------
     saliency_map
         Input saliciency map whose connected components will be calculated
         from.
@@ -160,6 +176,7 @@ def _get_largest_bounding_boxes(
 
     Returns
     -------
+        The N largest connected components as bounding boxes in a saliency map.
     """
 
     retval: list[BoundingBox] = []
@@ -181,6 +198,20 @@ def _compute_simultaneous_iou_and_ioda(
     This means that if there are multiple gt boxes, the detected area
     will be compared to them simultaneously (and not to each gt box
     separately).
+
+
+    Parameters
+    ----------
+    detected_box
+            BoundingBox of the detected area.
+    gt_bboxes
+            Collection of bounding boxes of the ground-truth drawn as
+            ``True`` values.
+
+
+    Returns
+    -------
+        The iou and ioda for the provided boxes.
     """
 
     detected_area = detected_box.area()
@@ -210,12 +241,12 @@ def _compute_avg_saliency_focus(
 
     Parameters
     ----------
-        gt_bboxes
-            Ground-truth bounding boxes in the format ``(x, y, width,
-            height)``.
-        gt_mask
-            Ground-truth mask containing the bounding boxes of the ground-truth
-            drawn as ``True`` values.
+    saliency_map
+        A real-valued saliency-map that conveys regions used for
+        classification in the original sample.
+    gt_mask
+        Ground-truth mask containing the bounding boxes of the ground-truth
+        drawn as ``True`` values.
 
 
     Returns
@@ -239,12 +270,12 @@ def _compute_proportional_energy(
 
     Parameters
     ----------
-        saliency_map
-            A real-valued saliency-map that conveys regions used for
-            classification in the original sample.
-        gt_mask
-            Ground-truth mask containing the bounding boxes of the ground-truth
-            drawn as ``True`` values.
+    saliency_map
+        A real-valued saliency-map that conveys regions used for
+        classification in the original sample.
+    gt_mask
+        Ground-truth mask containing the bounding boxes of the ground-truth
+        drawn as ``True`` values.
 
 
     Returns
@@ -268,15 +299,16 @@ def _compute_binary_mask(
 
     The binary_mask will be ON/True where the gt boxes are located.
 
+
     Parameters
     ----------
-        gt_bboxes
-            Ground-truth bounding boxes in the format ``(x, y, width,
-            height)``.
+    gt_bboxes
+        Ground-truth bounding boxes in the format ``(x, y, width,
+        height)``.
 
-        saliency_map
-            A real-valued saliency-map that conveys regions used for
-            classification in the original sample.
+    saliency_map
+        A real-valued saliency-map that conveys regions used for
+        classification in the original sample.
 
 
     Returns
@@ -305,6 +337,9 @@ def _process_sample(
     ----------
     gt_bboxes
         A list of ground-truth bounding boxes.
+    saliency_map
+        A real-valued saliency-map that conveys regions used for
+        classification in the original sample.
 
 
     Returns
@@ -364,7 +399,6 @@ def run(
 
     Returns
     -------
-
         A dictionary where keys are dataset names in the provide datamodule,
         and values are lists containing sample information alongside metrics
         calculated:
-- 
GitLab