From 1b1382bbf69c96f184b6f925534338a64d28374f Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 25 Jul 2023 17:48:44 +0200
Subject: [PATCH] Save more evaluation plots and combine results

Rework the evaluation flow: run() no longer receives the dataset or an
output folder and now returns the per-sample predictions and the score
histogram figure for each subset.  The evaluate script collects these
results for every subset and writes combined scores.pdf, plots.pdf and
table.txt files to the output folder.
---
 src/ptbench/engine/evaluator.py | 78 ++++++++++++++-------------------
 src/ptbench/scripts/evaluate.py | 77 +++++++++++++++++++++++++++++---
 2 files changed, 102 insertions(+), 53 deletions(-)
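
Notes (not part of the commit): a minimal sketch of how the reworked
run() interface is meant to be consumed, mirroring what the evaluate
script now does; the subset names and the predictions folder path below
are placeholders.

    from matplotlib.backends.backend_pdf import PdfPages

    from ptbench.engine.evaluator import run

    score_pdf = PdfPages("scores.pdf")  # one page per evaluated subset
    for subset in ("train", "test"):  # hypothetical subset names
        # run() no longer writes anything to disk; it returns the per-sample
        # prediction table, the score histogram figure and both thresholds
        pred_data, fig_score, maxf1_threshold, post_eer_threshold = run(
            name=subset,
            predictions_folder="results/predictions",  # placeholder path
            steps=1000,
        )
        score_pdf.savefig(fig_score)
    score_pdf.close()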

diff --git a/src/ptbench/engine/evaluator.py b/src/ptbench/engine/evaluator.py
index c157d5fe..0d736e1f 100644
--- a/src/ptbench/engine/evaluator.py
+++ b/src/ptbench/engine/evaluator.py
@@ -186,10 +186,8 @@ def sample_measures_for_threshold(
 
 
 def run(
-    dataset,
     name: str,
     predictions_folder: str,
-    output_folder: Optional[str | None] = None,
     f1_thresh: Optional[float] = None,
     eer_thresh: Optional[float] = None,
     steps: Optional[int] = 1000,
@@ -199,9 +197,6 @@ def run(
     Parameters
     ---------
 
-    dataset : py:class:`torch.utils.data.Dataset`
-        a dataset to iterate on
-
     name:
         the local name of this dataset (e.g. ``train``, or ``test``), to be
         used when saving measures files.
@@ -210,9 +205,6 @@ def run(
         folder where predictions for the dataset images has been previously
         stored
 
-    output_folder:
-        folder where to store results.
-
     f1_thresh:
         This number should come from
         the training set or a separate validation set.  Using a test set value
@@ -238,9 +230,7 @@ def run(
     post_eer_threshold : float
         Threshold achieving Equal Error Rate for this dataset
     """
-    predictions_path = os.path.join(
-        predictions_folder, f"predictions_{name}", "predictions.csv"
-    )
+    predictions_path = os.path.join(predictions_folder, f"{name}.csv")
 
     if not os.path.exists(predictions_path):
         predictions_path = predictions_folder
@@ -298,12 +288,12 @@ def run(
     )
     data_df = data_df.set_index("index")
 
-    # Save evaluation csv
+    """# Save evaluation csv
     if output_folder is not None:
         fullpath = os.path.join(output_folder, f"{name}.csv")
         logger.info(f"Saving {fullpath}...")
         os.makedirs(os.path.dirname(fullpath), exist_ok=True)
-        data_df.to_csv(fullpath)
+        data_df.to_csv(fullpath)"""
 
     # Find max F1 score
     f1_scores = numpy.asarray(data_df["f1_score"])
@@ -328,42 +318,38 @@ def run(
         f"threshold {post_eer_threshold:.3f} (chosen *a posteriori*)"
     )
 
-    # Save score table
-    if output_folder is not None:
-        fig, axes = plt.subplots(1)
-        fig.tight_layout(pad=3.0)
+    # Generate the score distribution figure
+    fig_score, axes = plt.subplots(1)
+    fig_score.tight_layout(pad=3.0)
 
-        # Names and bounds
-        axes.set_xlabel("Score")
-        axes.set_ylabel("Normalized counts")
-        axes.set_xlim(0.0, 1.0)
+    # Names and bounds
+    axes.set_xlabel("Score")
+    axes.set_ylabel("Normalized counts")
+    axes.set_xlim(0.0, 1.0)
 
-        neg_weights = numpy.ones_like(neg_gt["likelihood"]) / len(
-            pred_data["likelihood"]
-        )
-        pos_weights = numpy.ones_like(pos_gt["likelihood"]) / len(
-            pred_data["likelihood"]
-        )
-
-        axes.hist(
-            [neg_gt["likelihood"], pos_gt["likelihood"]],
-            weights=[neg_weights, pos_weights],
-            bins=100,
-            color=["tab:blue", "tab:orange"],
-            label=["Negatives", "Positives"],
-        )
-        axes.legend(prop={"size": 10}, loc="upper center")
-        axes.set_title(f"Score table for {name} subset")
+    neg_weights = numpy.ones_like(neg_gt["likelihood"]) / len(
+        pred_data["likelihood"]
+    )
+    pos_weights = numpy.ones_like(pos_gt["likelihood"]) / len(
+        pred_data["likelihood"]
+    )
 
-        # we should see some of axes 1 axes
-        axes.spines["right"].set_visible(False)
-        axes.spines["top"].set_visible(False)
-        axes.spines["left"].set_position(("data", -0.015))
+    axes.hist(
+        [neg_gt["likelihood"], pos_gt["likelihood"]],
+        weights=[neg_weights, pos_weights],
+        bins=100,
+        color=["tab:blue", "tab:orange"],
+        label=["Negatives", "Positives"],
+    )
+    axes.legend(prop={"size": 10}, loc="upper center")
+    axes.set_title(f"Score table for {name} subset")
 
-        fullpath = os.path.join(output_folder, f"{name}_score_table.pdf")
-        fig.savefig(fullpath)
+    # hide the right and top spines; offset the left spine slightly
+    axes.spines["right"].set_visible(False)
+    axes.spines["top"].set_visible(False)
+    axes.spines["left"].set_position(("data", -0.015))
 
-    if f1_thresh is not None and eer_thresh is not None:
+    """if f1_thresh is not None and eer_thresh is not None:
         # get the closest possible threshold we have
         index = int(round(steps * f1_thresh))
         f1_a_priori = data_df["f1_score"][index]
@@ -375,6 +361,6 @@ def run(
         )
 
         # Print the a priori EER threshold
-        logger.info(f"Equal error rate (chosen *a priori*) {eer_thresh:.3f}")
+        logger.info(f"Equal error rate (chosen *a priori*) {eer_thresh:.3f}")"""
 
-    return maxf1_threshold, post_eer_threshold
+    return pred_data, fig_score, maxf1_threshold, post_eer_threshold
diff --git a/src/ptbench/scripts/evaluate.py b/src/ptbench/scripts/evaluate.py
index c68ee796..9ebc4c0e 100644
--- a/src/ptbench/scripts/evaluate.py
+++ b/src/ptbench/scripts/evaluate.py
@@ -2,15 +2,21 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
+import os
+
+from collections import defaultdict
 from typing import Union
 
 import click
 
 from clapper.click import ConfigCommand, ResourceOption, verbosity_option
 from clapper.logging import setup
+from matplotlib.backends.backend_pdf import PdfPages
 
 from ..data.datamodule import CachingDataModule
 from ..data.typing import DataLoader
+from ..utils.plot import precision_recall_f1iso, roc_curve
+from ..utils.table import performance_table
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
@@ -117,7 +123,7 @@ def _validate_threshold(
     "the test set F1-score a priori performance",
     default=None,
     show_default=False,
-    required=False,
+    required=True,
     cls=ResourceOption,
 )
 @click.option(
@@ -159,8 +165,10 @@ def evaluate(
     if isinstance(threshold, str):
         # first run evaluation for reference dataset
         logger.info(f"Evaluating threshold on '{threshold}' set")
-        f1_threshold, eer_threshold = run(
-            _, threshold, predictions_folder, steps=steps
+        _, _, f1_threshold, eer_threshold = run(
+            name=threshold,
+            predictions_folder=predictions_folder,
+            steps=steps,
         )
 
         if (f1_threshold is not None) and (eer_threshold is not None):
@@ -173,17 +181,72 @@ def evaluate(
     else:
         raise ValueError("Threshold value is neither an int nor a float")
 
-    for k, v in dataloader.items():
+    results_dict = {  # type: ignore
+        "pred_data": defaultdict(dict),
+        "fig_score": defaultdict(dict),
+        "maxf1_threshold": defaultdict(dict),
+        "post_eer_threshold": defaultdict(dict),
+    }
+
+    for k in dataloader.keys():
         if k.startswith("_"):
             logger.info(f"Skipping dataset '{k}' (not to be evaluated)")
             continue
         logger.info(f"Analyzing '{k}' set...")
-        run(
-            v,
+        pred_data, fig_score, maxf1_threshold, post_eer_threshold = run(
             k,
             predictions_folder,
-            output_folder,
             f1_thresh=f1_threshold,
             eer_thresh=eer_threshold,
             steps=steps,
         )
+
+        results_dict["pred_data"][k] = pred_data
+        results_dict["fig_score"][k] = fig_score
+        results_dict["maxf1_threshold"][k] = maxf1_threshold
+        results_dict["post_eer_threshold"][k] = post_eer_threshold
+
+    if output_folder is not None:
+        output_scores = os.path.join(output_folder, "scores.pdf")
+
+        if output_scores is not None:
+            output_scores = os.path.realpath(output_scores)
+            logger.info(f"Creating and saving scores at {output_scores}...")
+            os.makedirs(os.path.dirname(output_scores), exist_ok=True)
+
+            score_pdf = PdfPages(output_scores)
+
+            for fig in results_dict["fig_score"].values():
+                score_pdf.savefig(fig)
+            score_pdf.close()
+
+        data = {}
+        for subset_name in results_dict["pred_data"].keys():
+            data[subset_name] = {
+                "df": results_dict["pred_data"][subset_name],
+                "threshold": results_dict["post_eer_threshold"][  # type: ignore
+                    threshold
+                ].item(),
+            }
+
+        output_figure = os.path.join(output_folder, "plots.pdf")
+
+        if output_figure is not None:
+            output_figure = os.path.realpath(output_figure)
+            logger.info(f"Creating and saving plots at {output_figure}...")
+            os.makedirs(os.path.dirname(output_figure), exist_ok=True)
+            pdf = PdfPages(output_figure)
+            pdf.savefig(precision_recall_f1iso(data))
+            pdf.savefig(roc_curve(data))
+            pdf.close()
+
+        output_table = os.path.join(output_folder, "table.txt")
+        logger.info("Tabulating performance summary...")
+        table = performance_table(data, "rst")
+        click.echo(table)
+        if output_table is not None:
+            output_table = os.path.realpath(output_table)
+            logger.info(f"Saving table at {output_table}...")
+            os.makedirs(os.path.dirname(output_table), exist_ok=True)
+            with open(output_table, "w") as f:
+                f.write(table)
-- 
GitLab