Skip to content
Snippets Groups Projects
Commit eff1554d authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

Evaluation table shows if threshold chosen a priori or posteriori

parent b82e7485
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
......@@ -48,7 +48,7 @@ def eer_threshold(neg: Iterable[float], pos: Iterable[float]) -> float:
fpr, tpr, thresholds = metrics.roc_curve(y_true, y_predictions, pos_label=1)
eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)
return interp1d(fpr, thresholds)(eer)
return float(interp1d(fpr, thresholds)(eer))
def posneg(
......@@ -198,12 +198,11 @@ def run(
---------
name:
the local name of this dataset (e.g. ``train``, or ``test``), to be
used when saving measures files.
The name of subset to load.
predictions_folder:
folder where predictions for the dataset images has been previously
stored
Folder where predictions for the dataset images has been previously
stored.
f1_thresh:
This number should come from
......@@ -224,11 +223,17 @@ def run(
Returns
-------
maxf1_threshold : float
Threshold to achieve the highest possible F1-score for this dataset
pred_data:
The loaded predictions for the specified subset.
post_eer_threshold : float
Threshold achieving Equal Error Rate for this dataset
fig_scores:
Figure of the histogram distributions of true-positive/true-negative scores.
maxf1_threshold:
Threshold to achieve the highest possible F1-score for this dataset.
post_eer_threshold:
Threshold achieving Equal Error Rate for this dataset.
"""
predictions_path = os.path.join(predictions_folder, f"{name}.csv")
......
......@@ -224,9 +224,12 @@ def evaluate(
for subset_name in dataloader.keys():
data[subset_name] = {
"df": results_dict["pred_data"][subset_name],
"threshold": results_dict["post_eer_threshold"][ # type: ignore
threshold
].item(),
"threshold": results_dict["post_eer_threshold"][threshold]
if isinstance(threshold, str)
else eer_threshold,
"threshold_type": f"posteriori [{threshold}]"
if isinstance(threshold, str)
else "priori",
}
output_figure = os.path.join(output_folder, "plots.pdf")
......
......@@ -47,6 +47,7 @@ def performance_table(data, fmt):
headers = [
"Dataset",
"T",
"T Type",
"F1 (95% CI)",
"Prec (95% CI)",
"Recall/Sen (95% CI)",
......@@ -61,6 +62,7 @@ def performance_table(data, fmt):
entry = [
k,
v["threshold"],
v["threshold_type"],
]
df = v["df"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment