Commit 21255dc6 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Merge branch 'metrics' into 'master'

Improve the metrics command with AUC and relaxed threshold criteria

See merge request !70
parents 9b6c5b9e 4a2698ea
Pipeline #36807 failed with stages
in 9 minutes and 32 seconds
......@@ -18,6 +18,8 @@ from bob.measure.script import common_options
from bob.measure.utils import get_fta
from gridtk.generator import expand
from tabulate import tabulate
from .pad_commands import CRITERIA
from .error_utils import calc_threshold
logger = logging.getLogger(__name__)
......@@ -94,8 +96,11 @@ Examples:
default=["train", "dev", "eval"],
)
@bool_option("sort", "s", "whether the table should be sorted.", True)
@common_options.criterion_option(lcriteria=CRITERIA, check=False)
@common_options.far_option()
@common_options.table_option()
@common_options.output_log_metric_option()
@common_options.decimal_option(dflt=2, short='-dec')
@verbosity_option()
@click.pass_context
def cross(
......@@ -109,6 +114,7 @@ def cross(
pai_names,
groups,
sort,
decimal,
verbose,
**kwargs
):
......@@ -161,7 +167,7 @@ def cross(
threshold = metrics[(database, protocol, algorithm, "dev")][1]
else:
try:
threshold = eer_threshold(neg, pos)
threshold = calc_threshold(ctx.meta["criterion"], pos, [neg], neg, ctx.meta['far_value'])
except RuntimeError:
logger.error("Something wrong with {}".format(score_path))
raise
......@@ -185,8 +191,8 @@ def cross(
rows = []
# sort the algorithms based on HTER test, EER dev, EER train
train_protocol = protocols[databases.index(train_database)]
if sort:
train_protocol = protocols[databases.index(train_database)]
def sort_key(alg):
r = []
......@@ -212,7 +218,7 @@ def cross(
cell += [far, frr, hter]
else:
cell += [hter]
cell = [round(c * 100, 1) for c in cell]
cell = [round(c * 100, decimal) for c in cell]
rows[-1].extend(cell)
title = " Trained on {} ".format(train_database)
......
......@@ -50,7 +50,7 @@ def metrics_option(
name="metrics",
help="List of metrics to print. Provide a string with comma separated metric "
"names. For possible values see the default value.",
default="apcer_pais,apcer_ap,bpcer,acer,fta,fpr,fnr,hter,far,frr,precision,recall,f1_score",
default="apcer_pais,apcer_ap,bpcer,acer,fta,fpr,fnr,hter,far,frr,precision,recall,f1_score,auc,auc-log-scale",
**kwargs
):
"""The metrics option"""
......@@ -157,6 +157,7 @@ def gen(ctx, outdir, mean_match, mean_non_match, n_sys, **kwargs):
command="bob pad metrics",
),
criteria=CRITERIA,
check_criteria=False,
epilog="""\b
More Examples:
\b
......
......@@ -2,7 +2,7 @@
import bob.measure.script.figure as measure_figure
from bob.measure.utils import get_fta_list
from bob.measure import farfrr, precision_recall, f_score
from bob.measure import farfrr, precision_recall, f_score, roc_auc_score
import bob.bio.base.script.figure as bio_figure
from .error_utils import calc_threshold, apcer_bpcer
import click
......@@ -65,6 +65,11 @@ class Metrics(bio_figure.Metrics):
# f_score
f1_score = f_score(all_negs, pos, threshold, 1)
# auc
auc = roc_auc_score(all_negs, pos)
auc_log = roc_auc_score(all_negs, pos, log_scale=True)
metrics = dict(
apcer_pais=apcer_pais,
apcer_ap=apcer_ap,
......@@ -83,13 +88,15 @@ class Metrics(bio_figure.Metrics):
precision=precision,
recall=recall,
f1_score=f1_score,
auc=auc,
)
metrics["auc-log-scale"] = auc_log
return metrics
def _strings(self, metrics):
n_dec = ".%df" % self._decimal
for k, v in metrics.items():
if k in ("precision", "recall", "f1_score"):
if k in ("precision", "recall", "f1_score", "auc", "auc-log-scale"):
metrics[k] = "%s" % format(v, n_dec)
elif k in ("np", "nn", "fp", "fn"):
continue
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment