diff --git a/conda/meta.yaml b/conda/meta.yaml
index a1bed5c468fbf89b641800779d7aab4f79888415..33434fb6550f0a393de6e6487f79e637238be3cc 100644
--- a/conda/meta.yaml
+++ b/conda/meta.yaml
@@ -25,6 +25,7 @@ requirements:
     - pip
     - clapper {{ clapper }}
     - click {{ click }}
+    - credible {{ credible }}
     - grad-cam {{ grad_cam }}
     - matplotlib {{ matplotlib }}
     - numpy {{ numpy }}
@@ -44,6 +45,7 @@ requirements:
     - python >=3.10
     - {{ pin_compatible('clapper') }}
     - {{ pin_compatible('click') }}
+    - {{ pin_compatible('credible') }}
     - {{ pin_compatible('grad-cam', max_pin='x.x') }}
     - {{ pin_compatible('matplotlib') }}
     - {{ pin_compatible('numpy') }}
diff --git a/pyproject.toml b/pyproject.toml
index 55db0f21b1b3c5415e41416f8a2ff074924c880b..0457cdff94604b93b689e4cf0176ac2080238010 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ classifiers = [
 dependencies = [
     "clapper",
     "click",
+    "credible",
     "numpy",
     "scipy",
     "scikit-image",
diff --git a/src/mednet/engine/evaluator.py b/src/mednet/engine/evaluator.py
index cd24260014c5374e692c9a10cce9427e859692dc..d8ba0d71fa8a89e4cff11f5a840389d227d3c9bc 100644
--- a/src/mednet/engine/evaluator.py
+++ b/src/mednet/engine/evaluator.py
@@ -10,6 +10,7 @@
 import typing
 from collections.abc import Iterable, Iterator
 
+import credible.curves
 import matplotlib.figure
 import numpy
 import numpy.typing
@@ -240,6 +241,7 @@ def run_binary(
     # point measures on threshold
     summary = dict(
         split=name,
+        num_samples=len(y_labels),
         threshold=use_threshold,
         threshold_a_posteriori=(threshold_a_priori is None),
         precision=sklearn.metrics.precision_score(
@@ -248,13 +250,20 @@
         recall=sklearn.metrics.recall_score(
             y_labels, y_predictions, pos_label=pos_label
         ),
+        f1_score=sklearn.metrics.f1_score(
+            y_labels, y_predictions, pos_label=pos_label
+        ),
+        average_precision_score=sklearn.metrics.average_precision_score(
+            y_labels, y_predictions, pos_label=pos_label
+        ),
         specificity=sklearn.metrics.recall_score(
             y_labels, y_predictions, pos_label=neg_label
         ),
-        accuracy=sklearn.metrics.accuracy_score(y_labels, y_predictions),
-        f1_score=sklearn.metrics.f1_score(
-            y_labels, y_predictions, pos_label=pos_label
+        auc_score=sklearn.metrics.roc_auc_score(
+            y_labels,
+            y_predictions,
         ),
+        accuracy=sklearn.metrics.accuracy_score(y_labels, y_predictions),
     )
 
     # figures: score distributions
@@ -471,7 +480,7 @@
     Parameters
     ----------
     data
-        A dictionary mapping split names to ROC curve data produced by
+        A dictionary mapping split names to Precision-Recall curve data produced by
        :py:func:sklearn.metrics.precision_recall_curve`.
     title
         The title of the plot.
@@ -503,8 +512,8 @@
     legend = []
 
     for name, (prec, recall, _) in data.items():
-        _auc = sklearn.metrics.auc(recall, prec)
-        label = f"{name} (AUC={_auc:.2f})"
+        _ap = credible.curves.average_metric([prec, recall])
+        label = f"{name} (AP={_ap:.2f})"
         color = next(colorcycler)
         style = next(linecycler)
 
diff --git a/src/mednet/scripts/evaluate.py b/src/mednet/scripts/evaluate.py
index 3c21ef632a25bcac0e39ba7d7678dc5737b2637f..10c6896256fa233d1aa2f8a5902cf0bc9e9c9beb 100644
--- a/src/mednet/scripts/evaluate.py
+++ b/src/mednet/scripts/evaluate.py
@@ -146,6 +146,14 @@ def evaluate(
     with table_path.open("w") as f:
         f.write(table)
 
+    machine_table_path = output_folder / "summary.json"
+
+    logger.info(
+        f"Also saving a machine-readable copy of measures at `{machine_table_path}`..."
+    )
+    with machine_table_path.open("w") as f:
+        json.dump(rows, f, indent=2)
+
     figure_path = output_folder / "plots.pdf"
     logger.info(f"Saving figures at `{figure_path}`...")