From 4b010f321de426decfed55e62e6bc4705db85f64 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Tue, 27 Feb 2024 13:14:12 +0100
Subject: [PATCH] [scripts,engine] Implement fixes on evaluation (closes #20),
 and prepare for MLflow integration (issue #60)

* Use average-precision instead of AUC for precision-recall figures
* Report average precision on tables
* Add a dependency on credible (prepare to handle issue #43)
* Also save the result table in JSON format (prepare to handle issue #60)
---
 conda/meta.yaml                |  2 ++
 pyproject.toml                 |  1 +
 src/mednet/engine/evaluator.py | 21 +++++++++++++++------
 src/mednet/scripts/evaluate.py |  8 ++++++++
 4 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/conda/meta.yaml b/conda/meta.yaml
index a1bed5c4..33434fb6 100644
--- a/conda/meta.yaml
+++ b/conda/meta.yaml
@@ -25,6 +25,7 @@ requirements:
     - pip
     - clapper {{ clapper }}
     - click {{ click }}
+    - credible {{ credible }}
     - grad-cam {{ grad_cam }}
     - matplotlib {{ matplotlib }}
     - numpy {{ numpy }}
@@ -44,6 +45,7 @@ requirements:
     - python >=3.10
     - {{ pin_compatible('clapper') }}
     - {{ pin_compatible('click') }}
+    - {{ pin_compatible('credible') }}
     - {{ pin_compatible('grad-cam', max_pin='x.x') }}
     - {{ pin_compatible('matplotlib') }}
     - {{ pin_compatible('numpy') }}
diff --git a/pyproject.toml b/pyproject.toml
index 55db0f21..0457cdff 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ classifiers = [
 dependencies = [
   "clapper",
   "click",
+  "credible",
   "numpy",
   "scipy",
   "scikit-image",
diff --git a/src/mednet/engine/evaluator.py b/src/mednet/engine/evaluator.py
index cd242600..d8ba0d71 100644
--- a/src/mednet/engine/evaluator.py
+++ b/src/mednet/engine/evaluator.py
@@ -10,6 +10,7 @@ import typing
 
 from collections.abc import Iterable, Iterator
 
+import credible.curves
 import matplotlib.figure
 import numpy
 import numpy.typing
@@ -240,6 +241,7 @@ def run_binary(
     # point measures on threshold
     summary = dict(
         split=name,
+        num_samples=len(y_labels),
         threshold=use_threshold,
         threshold_a_posteriori=(threshold_a_priori is None),
         precision=sklearn.metrics.precision_score(
@@ -248,13 +250,20 @@ def run_binary(
         recall=sklearn.metrics.recall_score(
             y_labels, y_predictions, pos_label=pos_label
         ),
+        f1_score=sklearn.metrics.f1_score(
+            y_labels, y_predictions, pos_label=pos_label
+        ),
+        average_precision_score=sklearn.metrics.average_precision_score(
+            y_labels, y_predictions, pos_label=pos_label
+        ),
         specificity=sklearn.metrics.recall_score(
             y_labels, y_predictions, pos_label=neg_label
         ),
-        accuracy=sklearn.metrics.accuracy_score(y_labels, y_predictions),
-        f1_score=sklearn.metrics.f1_score(
-            y_labels, y_predictions, pos_label=pos_label
+        auc_score=sklearn.metrics.roc_auc_score(
+            y_labels,
+            y_predictions,
         ),
+        accuracy=sklearn.metrics.accuracy_score(y_labels, y_predictions),
     )
 
     # figures: score distributions
@@ -471,7 +480,7 @@ def aggregate_pr(
     Parameters
     ----------
     data
-        A dictionary mapping split names to ROC curve data produced by
+        A dictionary mapping split names to Precision-Recall curve data produced by
         :py:func:sklearn.metrics.precision_recall_curve`.
     title
         The title of the plot.
@@ -503,8 +512,8 @@ def aggregate_pr(
         legend = []
 
         for name, (prec, recall, _) in data.items():
-            _auc = sklearn.metrics.auc(recall, prec)
-            label = f"{name} (AUC={_auc:.2f})"
+            _ap = credible.curves.average_metric([prec, recall])
+            label = f"{name} (AP={_ap:.2f})"
             color = next(colorcycler)
             style = next(linecycler)
 
diff --git a/src/mednet/scripts/evaluate.py b/src/mednet/scripts/evaluate.py
index 3c21ef63..10c68962 100644
--- a/src/mednet/scripts/evaluate.py
+++ b/src/mednet/scripts/evaluate.py
@@ -146,6 +146,14 @@ def evaluate(
         with table_path.open("w") as f:
             f.write(table)
 
+        machine_table_path = output_folder / "summary.json"
+
+        logger.info(
+            f"Also saving a machine-readable copy of measures at `{machine_table_path}`..."
+        )
+        with machine_table_path.open("w") as f:
+            json.dump(rows, f, indent=2)
+
         figure_path = output_folder / "plots.pdf"
         logger.info(f"Saving figures at `{figure_path}`...")
 
-- 
GitLab