Commit 4b010f32 authored by André Anjos

[scripts,engine] Implement fixes on evaluation (closes #20), and prepare for MLflow integration (issue #60)

* Use average-precision instead of AUC for precision-recall figures (see the sketch below)
* Report average precision on tables
* Add a dependency on credible (prepare to handle issue #43)
* Also save the result table in JSON format (prepare to handle issue #60)
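As an aside, a minimal hypothetical sketch (not part of this commit) contrasting the two quantities mentioned in the first bullet: the trapezoidal area under a precision-recall curve, reported before, versus scikit-learn's average precision, reported after. Labels and scores are made up.

import sklearn.metrics

y_true = [0, 0, 1, 1, 1, 0, 1]                   # made-up ground truth
y_score = [0.1, 0.4, 0.35, 0.8, 0.65, 0.7, 0.2]  # made-up model scores

prec, recall, _ = sklearn.metrics.precision_recall_curve(y_true, y_score)
auc_pr = sklearn.metrics.auc(recall, prec)                     # trapezoidal area (old figures)
ap = sklearn.metrics.average_precision_score(y_true, y_score)  # average precision (new figures)
print(f"AUC(PR)={auc_pr:.2f}  AP={ap:.2f}")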
parent 5f2b57a9
1 merge request: !24 Implement fixes on evaluation (closes #20), and prepare for handling issue #60
@@ -25,6 +25,7 @@ requirements:
- pip
- clapper {{ clapper }}
- click {{ click }}
- credible {{ credible }}
- grad-cam {{ grad_cam }}
- matplotlib {{ matplotlib }}
- numpy {{ numpy }}
@@ -44,6 +45,7 @@ requirements:
- python >=3.10
- {{ pin_compatible('clapper') }}
- {{ pin_compatible('click') }}
- {{ pin_compatible('credible') }}
- {{ pin_compatible('grad-cam', max_pin='x.x') }}
- {{ pin_compatible('matplotlib') }}
- {{ pin_compatible('numpy') }}
@@ -29,6 +29,7 @@ classifiers = [
dependencies = [
"clapper",
"click",
"credible",
"numpy",
"scipy",
"scikit-image",
@@ -10,6 +10,7 @@ import typing
from collections.abc import Iterable, Iterator
import credible.curves
import matplotlib.figure
import numpy
import numpy.typing
@@ -240,6 +241,7 @@ def run_binary(
    # point measures on threshold
    summary = dict(
        split=name,
        num_samples=len(y_labels),
        threshold=use_threshold,
        threshold_a_posteriori=(threshold_a_priori is None),
        precision=sklearn.metrics.precision_score(
@@ -248,13 +250,20 @@
        recall=sklearn.metrics.recall_score(
            y_labels, y_predictions, pos_label=pos_label
        ),
        f1_score=sklearn.metrics.f1_score(
            y_labels, y_predictions, pos_label=pos_label
        ),
        average_precision_score=sklearn.metrics.average_precision_score(
            y_labels, y_predictions, pos_label=pos_label
        ),
        specificity=sklearn.metrics.recall_score(
            y_labels, y_predictions, pos_label=neg_label
        ),
        auc_score=sklearn.metrics.roc_auc_score(
            y_labels,
            y_predictions,
        ),
        accuracy=sklearn.metrics.accuracy_score(y_labels, y_predictions),
    )

    # figures: score distributions
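For reference, a small toy example (not from this commit) of the point measures that make up one row of the summary table above. The labels and predictions are invented, and pos_label=1 / neg_label=0 are assumptions.

import sklearn.metrics

y_labels = [0, 0, 1, 1, 1]       # invented ground truth
y_predictions = [0, 1, 1, 1, 0]  # invented thresholded predictions

row = dict(
    precision=sklearn.metrics.precision_score(y_labels, y_predictions, pos_label=1),
    recall=sklearn.metrics.recall_score(y_labels, y_predictions, pos_label=1),
    f1_score=sklearn.metrics.f1_score(y_labels, y_predictions, pos_label=1),
    specificity=sklearn.metrics.recall_score(y_labels, y_predictions, pos_label=0),
    accuracy=sklearn.metrics.accuracy_score(y_labels, y_predictions),
)
print(row)  # precision, recall and f1 come out around 0.67; specificity 0.5; accuracy 0.6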
@@ -471,7 +480,7 @@ def aggregate_pr(
    Parameters
    ----------
    data
        A dictionary mapping split names to Precision-Recall curve data produced by
        :py:func:`sklearn.metrics.precision_recall_curve`.
    title
        The title of the plot.
@@ -503,8 +512,8 @@
    legend = []
    for name, (prec, recall, _) in data.items():
        _ap = credible.curves.average_metric([prec, recall])
        label = f"{name} (AP={_ap:.2f})"
        color = next(colorcycler)
        style = next(linecycler)
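The call to credible.curves.average_metric above comes straight from the diff; its exact semantics are not documented here, so what follows is only a hedged numpy sketch of the step-wise average-precision sum, AP = sum_n (R_{n-1} - R_n) * P_n, that replaces the trapezoidal AUC in the curve labels.

import numpy
import sklearn.metrics

y_true = [0, 1, 1, 0, 1]             # made-up labels
y_score = [0.2, 0.9, 0.6, 0.4, 0.7]  # made-up scores

prec, recall, _ = sklearn.metrics.precision_recall_curve(y_true, y_score)

# recall is returned in decreasing order, hence the sign flip on the differences
ap = float(-numpy.sum(numpy.diff(recall) * prec[:-1]))
print(f"AP={ap:.2f}")  # agrees with sklearn.metrics.average_precision_score(y_true, y_score)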
@@ -146,6 +146,14 @@ def evaluate(
    with table_path.open("w") as f:
        f.write(table)

    machine_table_path = output_folder / "summary.json"
    logger.info(
        f"Also saving a machine-readable copy of measures at `{machine_table_path}`..."
    )
    with machine_table_path.open("w") as f:
        json.dump(rows, f, indent=2)

    figure_path = output_folder / "plots.pdf"
    logger.info(f"Saving figures at `{figure_path}`...")