Commit eff1554d authored by Daniel CARRON, committed by André Anjos

Evaluation table shows whether the threshold was chosen a priori or a posteriori

parent b82e7485
1 merge request: !6 Making use of LightningDataModule and simplification of data loading
```diff
@@ -48,7 +48,7 @@ def eer_threshold(neg: Iterable[float], pos: Iterable[float]) -> float:
     fpr, tpr, thresholds = metrics.roc_curve(y_true, y_predictions, pos_label=1)
     eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)
-    return interp1d(fpr, thresholds)(eer)
+    return float(interp1d(fpr, thresholds)(eer))

 def posneg(
```
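For context, `eer_threshold` derives the Equal Error Rate threshold from the ROC curve; the commit only wraps the result in `float()` so callers get a builtin scalar instead of a 0-d NumPy array. Below is a hedged, self-contained sketch of the same idea that uses a simpler nearest-point approximation instead of the `brentq`/`interp1d` root-finding above; `eer_threshold_sketch` and the toy scores are illustrative, not from the repository.

```python
# Minimal sketch (not the repository's exact implementation): approximate the
# EER threshold as the ROC point where FPR and FNR are closest.
import numpy
from sklearn import metrics


def eer_threshold_sketch(neg, pos):
    """Return the score threshold at which FPR is (approximately) equal to FNR."""
    y_true = numpy.concatenate([numpy.zeros(len(neg)), numpy.ones(len(pos))])
    y_score = numpy.concatenate([numpy.asarray(neg), numpy.asarray(pos)])
    fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score, pos_label=1)
    fnr = 1.0 - tpr  # false-negative rate at each candidate threshold
    idx = int(numpy.argmin(numpy.abs(fpr - fnr)))  # point closest to FPR == FNR
    # cast to a builtin float, as the commit above does, so callers never
    # receive a 0-d numpy array
    return float(thresholds[idx])


# toy negative/positive scores, illustrative only; prints 0.6 for this data
print(eer_threshold_sketch([0.2, 0.35, 0.4, 0.8], [0.3, 0.6, 0.7, 0.9]))
```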
```diff
@@ -198,12 +198,11 @@ def run(
     ---------
     name:
-        the local name of this dataset (e.g. ``train``, or ``test``), to be
-        used when saving measures files.
+        The name of subset to load.
     predictions_folder:
-        folder where predictions for the dataset images has been previously
-        stored
+        Folder where predictions for the dataset images has been previously
+        stored.
     f1_thresh:
         This number should come from
@@ -224,11 +223,17 @@ def run(
     Returns
     -------
-    maxf1_threshold : float
-        Threshold to achieve the highest possible F1-score for this dataset
-    post_eer_threshold : float
-        Threshold achieving Equal Error Rate for this dataset
+    pred_data:
+        The loaded predictions for the specified subset.
+    fig_scores:
+        Figure of the histogram distributions of true-positive/true-negative scores.
+    maxf1_threshold:
+        Threshold to achieve the highest possible F1-score for this dataset.
+    post_eer_threshold:
+        Threshold achieving Equal Error Rate for this dataset.
     """
     predictions_path = os.path.join(predictions_folder, f"{name}.csv")
```
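Per the expanded Returns section, callers of `run` now receive four values. A hedged usage fragment (the argument list is abbreviated and assumed from the docstring above, and is not runnable on its own):

```python
# Hedged fragment: argument names are assumptions taken from the docstring.
pred_data, fig_scores, maxf1_threshold, post_eer_threshold = run(
    name="validation",
    predictions_folder="results/predictions",
    # ... remaining arguments elided ...
)
fig_scores.savefig("score-distributions.pdf")  # TP/TN score histograms
print(f"max-F1 threshold: {maxf1_threshold:.3f}, EER threshold: {post_eer_threshold:.3f}")
```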
```diff
@@ -224,9 +224,12 @@ def evaluate(
     for subset_name in dataloader.keys():
         data[subset_name] = {
             "df": results_dict["pred_data"][subset_name],
-            "threshold": results_dict["post_eer_threshold"][  # type: ignore
-                threshold
-            ].item(),
+            "threshold": results_dict["post_eer_threshold"][threshold]
+            if isinstance(threshold, str)
+            else eer_threshold,
+            "threshold_type": f"posteriori [{threshold}]"
+            if isinstance(threshold, str)
+            else "priori",
         }

     output_figure = os.path.join(output_folder, "plots.pdf")
```
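The change above dispatches between two modes: a string `threshold` names the subset whose measured EER threshold is reused (a posteriori), while a numeric one falls back to the precomputed `eer_threshold` value (a priori), and the new `threshold_type` entry records which mode applied. A self-contained sketch of that dispatch; `select_threshold` itself and the sample values are illustrative only:

```python
# Sketch of the a-priori / a-posteriori dispatch introduced above; names
# mirror the diff, but this helper is not part of the repository.
def select_threshold(post_eer_thresholds, threshold, eer_threshold):
    """Return ``(value, type)`` for the decision threshold.

    A string ``threshold`` names the subset whose measured EER threshold is
    reused (a posteriori); otherwise the precomputed ``eer_threshold`` value
    is used (a priori), following the diff above.
    """
    if isinstance(threshold, str):
        return post_eer_thresholds[threshold], f"posteriori [{threshold}]"
    return eer_threshold, "priori"


print(select_threshold({"validation": 0.42}, "validation", None))
# -> (0.42, 'posteriori [validation]')
print(select_threshold({"validation": 0.42}, 0.5, 0.37))
# -> (0.37, 'priori')
```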
```diff
@@ -47,6 +47,7 @@ def performance_table(data, fmt):
     headers = [
         "Dataset",
         "T",
+        "T Type",
         "F1 (95% CI)",
         "Prec (95% CI)",
         "Recall/Sen (95% CI)",
@@ -61,6 +62,7 @@ def performance_table(data, fmt):
     entry = [
         k,
         v["threshold"],
+        v["threshold_type"],
     ]
     df = v["df"]
```
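With the new "T Type" column, each row now reports both the threshold and how it was obtained. A small sketch of how such a table could render, assuming `performance_table` formats its rows with the `tabulate` package (the numbers below are made up):

```python
# Illustrative rendering of the extended table; assumes the `tabulate`
# package is available, and all values shown are fabricated.
from tabulate import tabulate

headers = ["Dataset", "T", "T Type", "F1 (95% CI)"]
rows = [
    ["validation", 0.42, "posteriori [validation]", "0.81 (0.79-0.83)"],
    ["test", 0.42, "posteriori [validation]", "0.78 (0.75-0.80)"],
]
print(tabulate(rows, headers=headers, tablefmt="rst"))
```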