From 1b1382bbf69c96f184b6f925534338a64d28374f Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 25 Jul 2023 17:48:44 +0200 Subject: [PATCH] Evaluation script saves more plots, combines results --- src/ptbench/engine/evaluator.py | 78 ++++++++++++++------------------- src/ptbench/scripts/evaluate.py | 77 +++++++++++++++++++++++++++++--- 2 files changed, 102 insertions(+), 53 deletions(-) diff --git a/src/ptbench/engine/evaluator.py b/src/ptbench/engine/evaluator.py index c157d5fe..0d736e1f 100644 --- a/src/ptbench/engine/evaluator.py +++ b/src/ptbench/engine/evaluator.py @@ -186,10 +186,8 @@ def sample_measures_for_threshold( def run( - dataset, name: str, predictions_folder: str, - output_folder: Optional[str | None] = None, f1_thresh: Optional[float] = None, eer_thresh: Optional[float] = None, steps: Optional[int] = 1000, @@ -199,9 +197,6 @@ def run( Parameters --------- - dataset : py:class:`torch.utils.data.Dataset` - a dataset to iterate on - name: the local name of this dataset (e.g. ``train``, or ``test``), to be used when saving measures files. @@ -210,9 +205,6 @@ def run( folder where predictions for the dataset images has been previously stored - output_folder: - folder where to store results. - f1_thresh: This number should come from the training set or a separate validation set. Using a test set value @@ -238,9 +230,7 @@ def run( post_eer_threshold : float Threshold achieving Equal Error Rate for this dataset """ - predictions_path = os.path.join( - predictions_folder, f"predictions_{name}", "predictions.csv" - ) + predictions_path = os.path.join(predictions_folder, f"{name}.csv") if not os.path.exists(predictions_path): predictions_path = predictions_folder @@ -298,12 +288,12 @@ def run( ) data_df = data_df.set_index("index") - # Save evaluation csv + """# Save evaluation csv if output_folder is not None: fullpath = os.path.join(output_folder, f"{name}.csv") logger.info(f"Saving {fullpath}...") os.makedirs(os.path.dirname(fullpath), exist_ok=True) - data_df.to_csv(fullpath) + data_df.to_csv(fullpath)""" # Find max F1 score f1_scores = numpy.asarray(data_df["f1_score"]) @@ -328,42 +318,38 @@ def run( f"threshold {post_eer_threshold:.3f} (chosen *a posteriori*)" ) - # Save score table - if output_folder is not None: - fig, axes = plt.subplots(1) - fig.tight_layout(pad=3.0) + # Generate scores fig + fig_score, axes = plt.subplots(1) + fig_score.tight_layout(pad=3.0) - # Names and bounds - axes.set_xlabel("Score") - axes.set_ylabel("Normalized counts") - axes.set_xlim(0.0, 1.0) + # Names and bounds + axes.set_xlabel("Score") + axes.set_ylabel("Normalized counts") + axes.set_xlim(0.0, 1.0) - neg_weights = numpy.ones_like(neg_gt["likelihood"]) / len( - pred_data["likelihood"] - ) - pos_weights = numpy.ones_like(pos_gt["likelihood"]) / len( - pred_data["likelihood"] - ) - - axes.hist( - [neg_gt["likelihood"], pos_gt["likelihood"]], - weights=[neg_weights, pos_weights], - bins=100, - color=["tab:blue", "tab:orange"], - label=["Negatives", "Positives"], - ) - axes.legend(prop={"size": 10}, loc="upper center") - axes.set_title(f"Score table for {name} subset") + neg_weights = numpy.ones_like(neg_gt["likelihood"]) / len( + pred_data["likelihood"] + ) + pos_weights = numpy.ones_like(pos_gt["likelihood"]) / len( + pred_data["likelihood"] + ) - # we should see some of axes 1 axes - axes.spines["right"].set_visible(False) - axes.spines["top"].set_visible(False) - axes.spines["left"].set_position(("data", -0.015)) + axes.hist( + [neg_gt["likelihood"], 
pos_gt["likelihood"]], + weights=[neg_weights, pos_weights], + bins=100, + color=["tab:blue", "tab:orange"], + label=["Negatives", "Positives"], + ) + axes.legend(prop={"size": 10}, loc="upper center") + axes.set_title(f"Score table for {name} subset") - fullpath = os.path.join(output_folder, f"{name}_score_table.pdf") - fig.savefig(fullpath) + # we should see some of axes 1 axes + axes.spines["right"].set_visible(False) + axes.spines["top"].set_visible(False) + axes.spines["left"].set_position(("data", -0.015)) - if f1_thresh is not None and eer_thresh is not None: + """if f1_thresh is not None and eer_thresh is not None: # get the closest possible threshold we have index = int(round(steps * f1_thresh)) f1_a_priori = data_df["f1_score"][index] @@ -375,6 +361,6 @@ def run( ) # Print the a priori EER threshold - logger.info(f"Equal error rate (chosen *a priori*) {eer_thresh:.3f}") + logger.info(f"Equal error rate (chosen *a priori*) {eer_thresh:.3f}")""" - return maxf1_threshold, post_eer_threshold + return pred_data, fig_score, maxf1_threshold, post_eer_threshold diff --git a/src/ptbench/scripts/evaluate.py b/src/ptbench/scripts/evaluate.py index c68ee796..9ebc4c0e 100644 --- a/src/ptbench/scripts/evaluate.py +++ b/src/ptbench/scripts/evaluate.py @@ -2,15 +2,21 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import os + +from collections import defaultdict from typing import Union import click from clapper.click import ConfigCommand, ResourceOption, verbosity_option from clapper.logging import setup +from matplotlib.backends.backend_pdf import PdfPages from ..data.datamodule import CachingDataModule from ..data.typing import DataLoader +from ..utils.plot import precision_recall_f1iso, roc_curve +from ..utils.table import performance_table logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -117,7 +123,7 @@ def _validate_threshold( "the test set F1-score a priori performance", default=None, show_default=False, - required=False, + required=True, cls=ResourceOption, ) @click.option( @@ -159,8 +165,10 @@ def evaluate( if isinstance(threshold, str): # first run evaluation for reference dataset logger.info(f"Evaluating threshold on '{threshold}' set") - f1_threshold, eer_threshold = run( - _, threshold, predictions_folder, steps=steps + _, _, f1_threshold, eer_threshold = run( + name=threshold, + predictions_folder=predictions_folder, + steps=steps, ) if (f1_threshold is not None) and (eer_threshold is not None): @@ -173,17 +181,72 @@ def evaluate( else: raise ValueError("Threshold value is neither an int nor a float") - for k, v in dataloader.items(): + results_dict = { # type: ignore + "pred_data": defaultdict(dict), + "fig_score": defaultdict(dict), + "maxf1_threshold": defaultdict(dict), + "post_eer_threshold": defaultdict(dict), + } + + for k in dataloader.keys(): if k.startswith("_"): logger.info(f"Skipping dataset '{k}' (not to be evaluated)") continue logger.info(f"Analyzing '{k}' set...") - run( - v, + pred_data, fig_score, maxf1_threshold, post_eer_threshold = run( k, predictions_folder, - output_folder, f1_thresh=f1_threshold, eer_thresh=eer_threshold, steps=steps, ) + + results_dict["pred_data"][k] = pred_data + results_dict["fig_score"][k] = fig_score + results_dict["maxf1_threshold"][k] = maxf1_threshold + results_dict["post_eer_threshold"][k] = post_eer_threshold + + if output_folder is not None: + output_scores = os.path.join(output_folder, "scores.pdf") + + if output_scores is not None: + output_scores = os.path.realpath(output_scores) + 
logger.info(f"Creating and saving scores at {output_scores}...") + os.makedirs(os.path.dirname(output_scores), exist_ok=True) + + score_pdf = PdfPages(output_scores) + + for fig in results_dict["fig_score"].values(): + score_pdf.savefig(fig) + score_pdf.close() + + data = {} + for subset_name in dataloader.keys(): + data[subset_name] = { + "df": results_dict["pred_data"][subset_name], + "threshold": results_dict["post_eer_threshold"][ # type: ignore + threshold + ].item(), + } + + output_figure = os.path.join(output_folder, "plots.pdf") + + if output_figure is not None: + output_figure = os.path.realpath(output_figure) + logger.info(f"Creating and saving plots at {output_figure}...") + os.makedirs(os.path.dirname(output_figure), exist_ok=True) + pdf = PdfPages(output_figure) + pdf.savefig(precision_recall_f1iso(data)) + pdf.savefig(roc_curve(data)) + pdf.close() + + output_table = os.path.join(output_folder, "table.txt") + logger.info("Tabulating performance summary...") + table = performance_table(data, "rst") + click.echo(table) + if output_table is not None: + output_table = os.path.realpath(output_table) + logger.info(f"Saving table at {output_table}...") + os.makedirs(os.path.dirname(output_table), exist_ok=True) + with open(output_table, "w") as f: + f.write(table) -- GitLab