Commit 4a2698ea authored by Amir MOHAMMADI

Improve the metrics command with AUC and relaxed threshold criteria

parent e19c925c
Pipeline #36785 passed with stage
in 17 minutes and 28 seconds
......@@ -18,6 +18,8 @@ from bob.measure.script import common_options
from bob.measure.utils import get_fta
from gridtk.generator import expand
from tabulate import tabulate
from .pad_commands import CRITERIA
from .error_utils import calc_threshold
logger = logging.getLogger(__name__)
......@@ -94,8 +96,11 @@ Examples:
default=["train", "dev", "eval"],
)
@bool_option("sort", "s", "whether the table should be sorted.", True)
@common_options.criterion_option(lcriteria=CRITERIA, check=False)
@common_options.far_option()
@common_options.table_option()
@common_options.output_log_metric_option()
@common_options.decimal_option(dflt=2, short='-dec')
@verbosity_option()
@click.pass_context
def cross(
......@@ -109,6 +114,7 @@ def cross(
pai_names,
groups,
sort,
decimal,
verbose,
**kwargs
):
......@@ -161,7 +167,7 @@ def cross(
threshold = metrics[(database, protocol, algorithm, "dev")][1]
else:
try:
threshold = eer_threshold(neg, pos)
threshold = calc_threshold(ctx.meta["criterion"], pos, [neg], neg, ctx.meta['far_value'])
except RuntimeError:
logger.error("Something wrong with {}".format(score_path))
raise
......@@ -185,8 +191,8 @@ def cross(
rows = []
# sort the algorithms based on HTER test, EER dev, EER train
train_protocol = protocols[databases.index(train_database)]
if sort:
train_protocol = protocols[databases.index(train_database)]
def sort_key(alg):
r = []
......@@ -212,7 +218,7 @@ def cross(
cell += [far, frr, hter]
else:
cell += [hter]
cell = [round(c * 100, 1) for c in cell]
cell = [round(c * 100, decimal) for c in cell]
rows[-1].extend(cell)
title = " Trained on {} ".format(train_database)
......
......@@ -50,7 +50,7 @@ def metrics_option(
name="metrics",
help="List of metrics to print. Provide a string with comma separated metric "
"names. For possible values see the default value.",
default="apcer_pais,apcer_ap,bpcer,acer,fta,fpr,fnr,hter,far,frr,precision,recall,f1_score",
default="apcer_pais,apcer_ap,bpcer,acer,fta,fpr,fnr,hter,far,frr,precision,recall,f1_score,auc,auc-log-scale",
**kwargs
):
"""The metrics option"""
......@@ -157,6 +157,7 @@ def gen(ctx, outdir, mean_match, mean_non_match, n_sys, **kwargs):
command="bob pad metrics",
),
criteria=CRITERIA,
check_criteria=False,
epilog="""\b
More Examples:
\b
......
......@@ -2,7 +2,7 @@
import bob.measure.script.figure as measure_figure
from bob.measure.utils import get_fta_list
from bob.measure import farfrr, precision_recall, f_score
from bob.measure import farfrr, precision_recall, f_score, roc_auc_score
import bob.bio.base.script.figure as bio_figure
from .error_utils import calc_threshold, apcer_bpcer
import click
......@@ -65,6 +65,11 @@ class Metrics(bio_figure.Metrics):
# f_score
f1_score = f_score(all_negs, pos, threshold, 1)
# auc
auc = roc_auc_score(all_negs, pos)
auc_log = roc_auc_score(all_negs, pos, log_scale=True)
metrics = dict(
apcer_pais=apcer_pais,
apcer_ap=apcer_ap,
......@@ -83,13 +88,15 @@ class Metrics(bio_figure.Metrics):
precision=precision,
recall=recall,
f1_score=f1_score,
auc=auc,
)
metrics["auc-log-scale"] = auc_log
return metrics
def _strings(self, metrics):
n_dec = ".%df" % self._decimal
for k, v in metrics.items():
if k in ("precision", "recall", "f1_score"):
if k in ("precision", "recall", "f1_score", "auc", "auc-log-scale"):
metrics[k] = "%s" % format(v, n_dec)
elif k in ("np", "nn", "fp", "fn"):
continue
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment