From eff1554d2193433c612cdb842460b8cac1188d66 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 26 Jul 2023 10:56:44 +0200
Subject: [PATCH] Evaluation table shows whether threshold was chosen a priori
 or a posteriori

---
 src/ptbench/engine/evaluator.py | 23 ++++++++++++++---------
 src/ptbench/scripts/evaluate.py |  9 ++++++---
 src/ptbench/utils/table.py      |  2 ++
 3 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/src/ptbench/engine/evaluator.py b/src/ptbench/engine/evaluator.py
index 0d736e1f..a551dba6 100644
--- a/src/ptbench/engine/evaluator.py
+++ b/src/ptbench/engine/evaluator.py
@@ -48,7 +48,7 @@ def eer_threshold(neg: Iterable[float], pos: Iterable[float]) -> float:
     fpr, tpr, thresholds = metrics.roc_curve(y_true, y_predictions, pos_label=1)
 
     eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)
-    return interp1d(fpr, thresholds)(eer)
+    return float(interp1d(fpr, thresholds)(eer))
 
 
 def posneg(
@@ -198,12 +198,11 @@
     ---------
 
     name:
-        the local name of this dataset (e.g. ``train``, or ``test``), to be
-        used when saving measures files.
+        The name of the subset to load.
 
     predictions_folder:
-        folder where predictions for the dataset images has been previously
-        stored
+        Folder where predictions for the dataset images have been previously
+        stored.
 
     f1_thresh:
         This number should come from
@@ -224,11 +223,17 @@
     Returns
     -------
 
-    maxf1_threshold : float
-        Threshold to achieve the highest possible F1-score for this dataset
+    pred_data:
+        The loaded predictions for the specified subset.
 
-    post_eer_threshold : float
-        Threshold achieving Equal Error Rate for this dataset
+    fig_scores:
+        Figure of the histogram distributions of true-positive/true-negative scores.
+
+    maxf1_threshold:
+        Threshold to achieve the highest possible F1-score for this dataset.
+
+    post_eer_threshold:
+        Threshold achieving Equal Error Rate for this dataset.
     """
 
     predictions_path = os.path.join(predictions_folder, f"{name}.csv")
diff --git a/src/ptbench/scripts/evaluate.py b/src/ptbench/scripts/evaluate.py
index 9ebc4c0e..08be4ea6 100644
--- a/src/ptbench/scripts/evaluate.py
+++ b/src/ptbench/scripts/evaluate.py
@@ -224,9 +224,12 @@ def evaluate(
     for subset_name in dataloader.keys():
         data[subset_name] = {
             "df": results_dict["pred_data"][subset_name],
-            "threshold": results_dict["post_eer_threshold"][  # type: ignore
-                threshold
-            ].item(),
+            "threshold": results_dict["post_eer_threshold"][threshold]
+            if isinstance(threshold, str)
+            else eer_threshold,
+            "threshold_type": f"posteriori [{threshold}]"
+            if isinstance(threshold, str)
+            else "priori",
         }
 
     output_figure = os.path.join(output_folder, "plots.pdf")
diff --git a/src/ptbench/utils/table.py b/src/ptbench/utils/table.py
index cb4594b9..c9d35988 100644
--- a/src/ptbench/utils/table.py
+++ b/src/ptbench/utils/table.py
@@ -47,6 +47,7 @@ def performance_table(data, fmt):
     headers = [
         "Dataset",
         "T",
+        "T Type",
         "F1 (95% CI)",
         "Prec (95% CI)",
         "Recall/Sen (95% CI)",
@@ -61,6 +62,7 @@
         entry = [
             k,
             v["threshold"],
+            v["threshold_type"],
         ]
 
         df = v["df"]
-- 
GitLab
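
Note on the evaluator.py hunk: eer_threshold() finds the operating point
where the false-positive rate equals the false-negative rate (1 - TPR) by
interpolating the ROC curve. The sketch below reproduces that recipe outside
the patch, with made-up labels and scores; it also shows why the added
float() cast matters, since interp1d(...)(x) returns a 0-d numpy array
rather than a plain Python scalar.

    # Minimal sketch of the equal-error-rate recipe used by eer_threshold();
    # the labels and scores below are made up for illustration.
    from scipy.interpolate import interp1d
    from scipy.optimize import brentq
    from sklearn import metrics

    y_true = [0, 0, 0, 1, 1, 1]                       # hypothetical labels
    y_predictions = [0.1, 0.4, 0.35, 0.2, 0.65, 0.8]  # hypothetical scores

    fpr, tpr, thresholds = metrics.roc_curve(y_true, y_predictions, pos_label=1)

    # The EER is where FPR == 1 - TPR: find the root of their difference
    # along the interpolated ROC curve.
    eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)

    # interp1d returns a 0-d numpy array; float() unwraps it so callers get
    # a plain scalar (which is why .item() could be dropped in evaluate.py).
    print(float(interp1d(fpr, thresholds)(eer)))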
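
On the evaluate.py hunk: when the threshold option is a subset name, the
operating point is read off that subset's scores (a posteriori); when it is
a number fixed up front, the choice is a priori, and the new
"threshold_type" entry records which case applied so the table can display
it. A rough sketch of that branching follows, using the names from the
patch; results_dict, threshold and eer_threshold are assumed to exist in the
surrounding (unshown) script code, with eer_threshold holding the numeric
threshold prepared earlier.

    # Sketch of the a-priori/a-posteriori bookkeeping added to evaluate().
    if isinstance(threshold, str):
        # threshold names a subset (e.g. "validation"): use the EER operating
        # point measured on that subset -- chosen a posteriori.
        value = results_dict["post_eer_threshold"][threshold]
        kind = f"posteriori [{threshold}]"
    else:
        # threshold was supplied as a number up front -- an a priori choice.
        value = eer_threshold
        kind = "priori"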