Commit 661ab557 authored by André Anjos

[engine.classify.evaluator] Fix small numpy-related API change on credible

parent 774e6ef2
Pipeline #95347 failed
@@ -133,10 +133,12 @@ def run(
     name: str,
     predictions: typing.Sequence[Prediction],
     binning: str | int,
+    rng: numpy.random.Generator,
     threshold_a_priori: float | None = None,
     credible_regions: bool = False,
 ) -> dict[str, typing.Any]:
-    """Run inference and calculate measures for binary or multilabel classification.
+    """Run inference and calculate measures for binary or multilabel
+    classification.
 
     For multi-label problems, calculate the metrics in the "micro" sense by
     first rasterizing all scores and labels (with :py:func:`numpy.ravel`), and
@@ -152,6 +154,8 @@ def run(
         The binning algorithm to use for computing the bin widths and
         distribution for histograms. Choose from algorithms supported by
         :py:func:`numpy.histogram`.
+    rng
+        An initialized numpy random number generator.
     threshold_a_priori
         A threshold to use, evaluated *a priori*, if single values must be reported.
         If this value is not provided, an *a posteriori* threshold is calculated
@@ -218,7 +222,7 @@ def run(
             f"(samples = {len(predictions)}) - "
            f"note this can be slow on very large datasets..."
         )
-        f1 = credible.bayesian.metrics.f1_score(y_labels, y_predictions)
+        f1 = credible.bayesian.metrics.f1_score(y_labels, y_predictions, rng=rng)
         roc_auc = credible.bayesian.metrics.roc_auc_score(y_labels, y_scores)
         precision = credible.bayesian.metrics.precision_score(y_labels, y_predictions)
         recall = credible.bayesian.metrics.recall_score(y_labels, y_predictions)
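For context, the evaluator-side change only threads an explicit `numpy.random.Generator` into `credible.bayesian.metrics.f1_score`. Passing a generator instead of relying on numpy's global state is what makes the Monte Carlo sampling behind the credible regions reproducible. A minimal, numpy-only sketch of that property (the beta draws below merely stand in for whatever sampling `credible` performs internally):

```python
import numpy

# Two generators built from the same seed produce identical Monte Carlo draws,
# so credible-region estimates computed from them are reproducible across runs.
rng_a = numpy.random.default_rng(42)
rng_b = numpy.random.default_rng(42)

draws_a = rng_a.beta(2.0, 5.0, size=10_000)  # stand-in for the real sampling
draws_b = rng_b.beta(2.0, 5.0, size=10_000)

assert numpy.array_equal(draws_a, draws_b)
print(numpy.quantile(draws_a, [0.025, 0.975]))  # e.g. a 95% credible interval
```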
@@ -114,6 +114,17 @@ logger = setup_cli_logger()
     default=False,
     cls=ResourceOption,
 )
+@click.option(
+    "--seed",
+    "-s",
+    help="""Seed to use for the random number generator (used when doing Monte Carlo
+    simulations required for the evaluation of credible regions for F1-score).""",
+    show_default=True,
+    required=False,
+    default=42,
+    type=click.IntRange(min=0),
+    cls=ResourceOption,
+)
 @verbosity_option(logger=logger, expose_value=False)
 def evaluate(
     predictions: pathlib.Path,
@@ -122,6 +133,7 @@ def evaluate(
     binning: str,
     plot: bool,
     credible_regions: bool,
+    seed: int,
     **_,  # ignored
 ) -> None:  # numpydoc ignore=PR01
     """Evaluate predictions (from a model) on a classification task."""
@@ -129,6 +141,7 @@ def evaluate(
     import typing

     import matplotlib.backends.backend_pdf
+    import numpy

     from ...engine.classify.evaluator import make_plots, make_table, run
     from ..utils import save_json_metadata, save_json_with_backup
@@ -167,6 +180,8 @@ def evaluate(
             or can not be converted to a float. Check your input.""",
         )

+    rng = numpy.random.default_rng(seed)
+
     results: dict[str, dict[str, typing.Any]] = dict()
     for k, v in predict_data.items():
         logger.info(f"Computing performance on split `{k}`...")
@@ -174,6 +189,7 @@ def evaluate(
             name=k,
             predictions=v,
             binning=int(binning) if binning.isnumeric() else binning,
+            rng=rng,
             threshold_a_priori=use_threshold,
             credible_regions=credible_regions,
         )
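On the CLI side, the new `--seed` option is converted once into a `numpy.random.Generator` that every `run()` call then shares. A stripped-down sketch of that wiring, assuming a bare click command (the real command also uses `ResourceOption` and many other options omitted here; the split names are hypothetical):

```python
import click
import numpy


@click.command()
@click.option(
    "--seed",
    "-s",
    type=click.IntRange(min=0),
    default=42,
    show_default=True,
    help="Seed for the Monte Carlo simulations behind the credible regions.",
)
def evaluate(seed: int) -> None:
    # One generator is created from the seed and reused for every split, so a
    # given seed makes the whole evaluation deterministic end-to-end.
    rng = numpy.random.default_rng(seed)
    for split in ("train", "validation", "test"):  # hypothetical split names
        # ...the real code passes `rng=rng` into each run(...) call here...
        click.echo(f"{split}: first draw = {rng.random():.6f}")


if __name__ == "__main__":
    evaluate()
```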