diff --git a/src/ptbench/engine/evaluator.py b/src/ptbench/engine/evaluator.py
index 0d736e1faa2f8f46ef835d86bd39d952f62c7f20..a551dba65f4ed0deaab1b4a5ef541accaf79d454 100644
--- a/src/ptbench/engine/evaluator.py
+++ b/src/ptbench/engine/evaluator.py
@@ -48,7 +48,7 @@ def eer_threshold(neg: Iterable[float], pos: Iterable[float]) -> float:
     fpr, tpr, thresholds = metrics.roc_curve(y_true, y_predictions, pos_label=1)
 
     eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)
-    return interp1d(fpr, thresholds)(eer)
+    return float(interp1d(fpr, thresholds)(eer))
 
 
 def posneg(
@@ -198,12 +198,11 @@ def run(
     ---------
 
     name:
-        the local name of this dataset (e.g. ``train``, or ``test``), to be
-        used when saving measures files.
+        The name of the subset to load.
 
     predictions_folder:
-        folder where predictions for the dataset images has been previously
-        stored
+        Folder where predictions for the dataset images have been previously
+        stored.
 
     f1_thresh:
         This number should come from
@@ -224,11 +223,17 @@ def run(
     Returns
     -------
 
-    maxf1_threshold : float
-        Threshold to achieve the highest possible F1-score for this dataset
+    pred_data:
+        The loaded predictions for the specified subset.
 
-    post_eer_threshold : float
-        Threshold achieving Equal Error Rate for this dataset
+    fig_scores:
+        Figure of the histogram distributions of true-positive/true-negative scores.
+
+    maxf1_threshold:
+        Threshold to achieve the highest possible F1-score for this dataset.
+
+    post_eer_threshold:
+        Threshold achieving Equal Error Rate for this dataset.
     """
 
     predictions_path = os.path.join(predictions_folder, f"{name}.csv")
diff --git a/src/ptbench/scripts/evaluate.py b/src/ptbench/scripts/evaluate.py
index 9ebc4c0ee8a845ee9856a31c887a124950511de8..08be4ea68afa5c73481cf29a1a593bb81eeb373b 100644
--- a/src/ptbench/scripts/evaluate.py
+++ b/src/ptbench/scripts/evaluate.py
@@ -224,9 +224,12 @@ def evaluate(
     for subset_name in dataloader.keys():
         data[subset_name] = {
             "df": results_dict["pred_data"][subset_name],
-            "threshold": results_dict["post_eer_threshold"][  # type: ignore
-                threshold
-            ].item(),
+            "threshold": results_dict["post_eer_threshold"][threshold]
+            if isinstance(threshold, str)
+            else eer_threshold,
+            "threshold_type": f"posteriori [{threshold}]"
+            if isinstance(threshold, str)
+            else "priori",
         }
 
     output_figure = os.path.join(output_folder, "plots.pdf")
diff --git a/src/ptbench/utils/table.py b/src/ptbench/utils/table.py
index cb4594b99a9a242bcbb45669ba344f96e8d5b418..c9d35988978df6f4a9a7721e399c27be1d2e8d68 100644
--- a/src/ptbench/utils/table.py
+++ b/src/ptbench/utils/table.py
@@ -47,6 +47,7 @@ def performance_table(data, fmt):
     headers = [
         "Dataset",
         "T",
+        "T Type",
         "F1 (95% CI)",
         "Prec (95% CI)",
         "Recall/Sen (95% CI)",
@@ -61,6 +62,7 @@ def performance_table(data, fmt):
         entry = [
             k,
             v["threshold"],
+            v["threshold_type"],
         ]
 
         df = v["df"]
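
Reviewer note (not part of the patch): below is a minimal, self-contained sketch of the threshold-selection logic this change introduces in evaluate.py and of how the new "T Type" column in performance_table() could render. All names here (pick_threshold, the example numbers) are hypothetical stand-ins, and the tabulate call is an assumption suggested by the fmt argument but not shown in the hunks; in the patch itself the a-priori branch reuses an eer_threshold value presumably computed earlier in evaluate(), which is not visible in this diff.

from tabulate import tabulate  # assumed dependency; the real rendering code is not shown in the hunks


def pick_threshold(threshold, post_eer_thresholds):
    """Mirror of the patched selection: a subset name reuses that subset's
    a-posteriori EER threshold; anything else is treated as an a-priori value."""
    if isinstance(threshold, str):
        return post_eer_thresholds[threshold], f"posteriori [{threshold}]"
    return float(threshold), "priori"


# Hypothetical values standing in for results_dict["post_eer_threshold"]
post_eer = {"train": 0.42, "validation": 0.39}

rows = []
for subset in ("train", "validation"):
    # Reuse the threshold derived on the "train" subset for every evaluated subset
    value, kind = pick_threshold("train", post_eer)
    rows.append([subset, f"{value:.3f}", kind])

print(tabulate(rows, headers=["Dataset", "T", "T Type"], tablefmt="rst"))

Run standalone, this prints a small reST table whose first three columns match the new header layout added to utils/table.py.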