From 4b010f321de426decfed55e62e6bc4705db85f64 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Tue, 27 Feb 2024 13:14:12 +0100
Subject: [PATCH] [scripts,engine] Implement fixes on evaluation (closes #20),
 and prepare for MLflow integration (issue #60)

* Use average precision instead of AUC for precision-recall figures
* Report average precision in tables
* Include dependency on credible (prepares to handle issue #43)
* Also save the result table in JSON format (prepares to handle issue #60)
---
 conda/meta.yaml                |  2 ++
 pyproject.toml                 |  1 +
 src/mednet/engine/evaluator.py | 21 +++++++++++++++------
 src/mednet/scripts/evaluate.py |  8 ++++++++
 4 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/conda/meta.yaml b/conda/meta.yaml
index a1bed5c4..33434fb6 100644
--- a/conda/meta.yaml
+++ b/conda/meta.yaml
@@ -25,6 +25,7 @@ requirements:
     - pip
     - clapper {{ clapper }}
     - click {{ click }}
+    - credible {{ credible }}
    - grad-cam {{ grad_cam }}
     - matplotlib {{ matplotlib }}
     - numpy {{ numpy }}
@@ -44,6 +45,7 @@ requirements:
     - python >=3.10
     - {{ pin_compatible('clapper') }}
     - {{ pin_compatible('click') }}
+    - {{ pin_compatible('credible') }}
     - {{ pin_compatible('grad-cam', max_pin='x.x') }}
     - {{ pin_compatible('matplotlib') }}
     - {{ pin_compatible('numpy') }}
diff --git a/pyproject.toml b/pyproject.toml
index 55db0f21..0457cdff 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ classifiers = [
 dependencies = [
   "clapper",
   "click",
+  "credible",
   "numpy",
   "scipy",
   "scikit-image",
diff --git a/src/mednet/engine/evaluator.py b/src/mednet/engine/evaluator.py
index cd242600..d8ba0d71 100644
--- a/src/mednet/engine/evaluator.py
+++ b/src/mednet/engine/evaluator.py
@@ -10,6 +10,7 @@ import typing
 
 from collections.abc import Iterable, Iterator
 
+import credible.curves
 import matplotlib.figure
 import numpy
 import numpy.typing
@@ -240,6 +241,7 @@ def run_binary(
     # point measures on threshold
     summary = dict(
         split=name,
+        num_samples=len(y_labels),
         threshold=use_threshold,
         threshold_a_posteriori=(threshold_a_priori is None),
         precision=sklearn.metrics.precision_score(
@@ -248,13 +250,20 @@
         recall=sklearn.metrics.recall_score(
             y_labels, y_predictions, pos_label=pos_label
         ),
+        f1_score=sklearn.metrics.f1_score(
+            y_labels, y_predictions, pos_label=pos_label
+        ),
+        average_precision_score=sklearn.metrics.average_precision_score(
+            y_labels, y_predictions, pos_label=pos_label
+        ),
         specificity=sklearn.metrics.recall_score(
             y_labels, y_predictions, pos_label=neg_label
         ),
-        accuracy=sklearn.metrics.accuracy_score(y_labels, y_predictions),
-        f1_score=sklearn.metrics.f1_score(
-            y_labels, y_predictions, pos_label=pos_label
+        auc_score=sklearn.metrics.roc_auc_score(
+            y_labels,
+            y_predictions,
         ),
+        accuracy=sklearn.metrics.accuracy_score(y_labels, y_predictions),
     )
 
     # figures: score distributions
@@ -471,7 +480,7 @@ def aggregate_pr(
     Parameters
     ----------
     data
-        A dictionary mapping split names to ROC curve data produced by
+        A dictionary mapping split names to Precision-Recall curve data produced by
         :py:func:`sklearn.metrics.precision_recall_curve`.
     title
         The title of the plot.
@@ -503,8 +512,8 @@ def aggregate_pr(
     legend = []
 
     for name, (prec, recall, _) in data.items():
-        _auc = sklearn.metrics.auc(recall, prec)
-        label = f"{name} (AUC={_auc:.2f})"
+        _ap = credible.curves.average_metric([prec, recall])
+        label = f"{name} (AP={_ap:.2f})"
         color = next(colorcycler)
         style = next(linecycler)
 
diff --git a/src/mednet/scripts/evaluate.py b/src/mednet/scripts/evaluate.py
index 3c21ef63..10c68962 100644
--- a/src/mednet/scripts/evaluate.py
+++ b/src/mednet/scripts/evaluate.py
@@ -146,6 +146,14 @@ def evaluate(
     with table_path.open("w") as f:
         f.write(table)
 
+    machine_table_path = output_folder / "summary.json"
+
+    logger.info(
+        f"Also saving a machine-readable copy of measures at `{machine_table_path}`..."
+    )
+    with machine_table_path.open("w") as f:
+        json.dump(rows, f, indent=2)
+
     figure_path = output_folder / "plots.pdf"
 
     logger.info(f"Saving figures at `{figure_path}`...")
-- 
GitLab
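
A short illustration of the two behavioural changes in this patch may help. The sketch below is hypothetical and not mednet code: it uses sklearn.metrics.average_precision_score as a stand-in for credible.curves.average_metric (whose API is not shown beyond the single call in the diff), and the labels, scores, and JSON row contents are invented for the demo. It shows why average precision replaces a trapezoidal AUC over the precision-recall curve, and how a summary.json sidecar like the one evaluate.py now writes can be produced.

    # Hypothetical sketch (not mednet code); labels, scores, and JSON keys
    # below are invented for illustration.
    import json
    import pathlib

    import numpy
    import sklearn.metrics

    rng = numpy.random.default_rng(42)
    y_labels = rng.integers(0, 2, size=200)  # fake binary ground truth
    y_scores = numpy.clip(0.5 * y_labels + 0.6 * rng.random(200), 0.0, 1.0)

    # Old legend value: trapezoidal integration of the PR curve.  This
    # linearly interpolates precision between operating points, which tends
    # to be optimistic for precision-recall curves.
    prec, recall, _ = sklearn.metrics.precision_recall_curve(y_labels, y_scores)
    pr_auc = sklearn.metrics.auc(recall, prec)

    # New legend value: average precision, a step-wise weighted sum of
    # precisions that does not interpolate between operating points.
    ap = sklearn.metrics.average_precision_score(y_labels, y_scores)
    print(f"trapezoidal PR-AUC={pr_auc:.3f} vs AP={ap:.3f}")

    # Machine-readable sidecar, mirroring the summary.json written by
    # evaluate.py in the patch (row contents are illustrative).
    rows = [{"split": "validation", "num_samples": int(y_labels.size),
             "average_precision_score": float(ap)}]
    with pathlib.Path("summary.json").open("w") as f:
        json.dump(rows, f, indent=2)

Computing AP directly from the (precision, recall) arrays, as the credible.curves.average_metric call in aggregate_pr appears to do, also avoids threading the raw scores through to the plotting code: the curve points already at hand are enough.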