Commit bcccc06b authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Add a command for multi protocol (N-fold cross validation) analysis

parent 0ef7907e
Pipeline #21320 passed with stage
in 25 minutes and 49 seconds
......@@ -11,10 +11,12 @@ SCORE_FORMAT = (
CRITERIA = ('eer', 'min-hter', 'far')
@common_options.metrics_command(common_options.METRICS_HELP.format(
names='FtA, FAR, FRR, FMR, FMNR, HTER',
criteria=CRITERIA, score_format=SCORE_FORMAT,
command='bob measure metrics'), criteria=CRITERIA)
@common_options.metrics_command(
common_options.METRICS_HELP.format(
names='FtA, FAR, FRR, FMR, FMNR, HTER',
criteria=CRITERIA, score_format=SCORE_FORMAT,
command='bob measure metrics'),
criteria=CRITERIA)
def metrics(ctx, scores, evaluation, **kwargs):
    '''Compute and print the metrics table for the given score files.

    ``figure.Metrics`` drives the per-system computation; ``load.split``
    is the loader used to parse each score file.  Options parsed by the
    decorator land in ``ctx.meta`` and ``kwargs``.
    '''
    process = figure.Metrics(ctx, scores, evaluation, load.split)
    process.run()
......@@ -59,3 +61,15 @@ def hist(ctx, scores, evaluation, **kwargs):
def evaluate(ctx, scores, evaluation, **kwargs):
    '''Run the full evaluation flow: metrics plus ROC, DET, EPC and
    histogram figures, by forwarding the shared context to each
    sub-command in turn.'''
    common_options.evaluate_flow(
        ctx, scores, evaluation, metrics, roc, det, epc, hist, **kwargs)
@common_options.multi_metrics_command(
    common_options.MULTI_METRICS_HELP.format(
        names='FtA, FAR, FRR, FMR, FMNR, HTER',
        criteria=CRITERIA, score_format=SCORE_FORMAT,
        command='bob measure multi-metrics'),
    criteria=CRITERIA)
def multi_metrics(ctx, scores, evaluation, protocols_number, **kwargs):
    '''Compute cross-validation (multi-protocol) metrics for the given
    score files.'''
    # Each protocol contributes one dev file, plus one eval file when
    # --eval is given, so the minimum number of score files scales with
    # the protocol count.
    ctx.meta['min_arg'] = protocols_number * (2 if evaluation else 1)
    figure.MultiMetrics(ctx, scores, evaluation, load.split).run()
......@@ -21,7 +21,7 @@ def scores_argument(min_arg=1, force_eval=False, **kwargs):
----------
min_arg : int
the minimum number of file needed to evaluate a system. For example,
PAD functionalities need licit and spoof and therefore min_arg = 2
vulnerability analysis needs licit and spoof and therefore min_arg = 2
Returns
-------
......@@ -920,3 +920,67 @@ def evaluate_flow(ctx, scores, evaluation, metrics, roc, det, epc, hist,
ctx.forward(hist)
click.echo("Evaluate successfully completed!")
click.echo("[plots] => %s" % (ctx.meta['output']))
def n_protocols_option(required=True, **kwargs):
    '''Build the ``-pn/--protocols-number`` click option.

    The parsed value is stored (as its absolute value) both in
    ``ctx.meta['protocols_number']`` and in the command's parameters.

    Parameters
    ----------
    required : bool
        whether the option must be provided on the command line
    **kwargs
        forwarded verbatim to :any:`click.option`
    '''
    def decorator(func):
        def store_count(ctx, param, value):
            # A negative count makes no sense; keep the magnitude only.
            value = abs(value)
            ctx.meta['protocols_number'] = value
            return value

        option = click.option(
            '-pn', '--protocols-number', type=click.INT,
            show_default=True, required=required,
            help='The number of protocols of cross validation.',
            callback=store_count, **kwargs)
        return option(func)
    return decorator
def multi_metrics_command(docstring, criteria=('eer', 'min-hter', 'far')):
    '''Decorator factory that assembles the full option stack for a
    multi-protocol (cross-validation) metrics command.

    Parameters
    ----------
    docstring : str
        help text installed as the wrapped command's ``__doc__``
    criteria : tuple
        threshold criteria accepted by the ``--criterion`` option
    '''
    def decorator(func):
        func.__doc__ = docstring

        # NOTE: decorator order is significant for click; keep it fixed.
        @click.command('multi-metrics')
        @scores_argument(nargs=-1)
        @eval_option()
        @n_protocols_option()
        @table_option()
        @output_log_metric_option()
        @criterion_option(criteria)
        @thresholds_option()
        @far_option()
        @legends_option()
        @open_file_mode_option()
        @verbosity_option()
        @click.pass_context
        @functools.wraps(func)
        def wrapper(*args, **kwds):
            return func(*args, **kwds)

        return wrapper
    return decorator
# Help text for the `multi-metrics` command; the placeholders are filled in
# by `multi_metrics_command` via str.format.  Fix: "as many as development
# score files as" -> "as many development score files as".
MULTI_METRICS_HELP = """Multi protocol (cross-validation) metrics.
Prints a table that contains {names} for a given threshold criterion
({criteria}). The metrics are averaged over several protocols. The idea is
that each protocol corresponds to one fold in your cross-validation.
You need to provide as many development score files as the number of
protocols per system. You can also provide evaluation files along with dev
files. If evaluation scores are provided, you must use flag `--eval`. The
number of protocols must be provided using the `--protocols-number` option.
{score_format}
Resulting table format can be changed using the `--tablefmt`.
Examples:
$ {command} -v {{p1,p2,p3}}/scores-dev
$ {command} -v -e {{p1,p2,p3}}/scores-{{dev,eval}}
$ {command} -v -e {{sys1,sys2}}/{{p1,p2,p3}}/scores-{{dev,eval}}
"""
......@@ -5,6 +5,7 @@ from abc import ABCMeta, abstractmethod
import math
import sys
import os.path
import numpy
import click
import matplotlib
import matplotlib.pyplot as mpl
......@@ -128,7 +129,6 @@ class MeasureBase(object):
# and if only dev:
# [ (dev_licit_neg, dev_licit_pos), (dev_spoof_neg, dev_licit_pos)]
# Things to do after the main iterative computations are done
@abstractmethod
def end_process(self):
......@@ -192,6 +192,29 @@ class Metrics(MeasureBase):
def get_thres(self, criterion, dev_neg, dev_pos, far):
return utils.get_thres(criterion, dev_neg, dev_pos, far)
def _numbers(self, neg, pos, threshold, fta):
    '''Return the raw rates and counts for one (neg, pos) score pair.

    Returns the tuple ``(fta, fmr, fnmr, hter, far, frr, fm, ni, fnm,
    nc)`` where the last four entries are the false-accept count, the
    number of impostors, the false-reject count and the number of
    clients.
    '''
    from .. import farfrr
    fmr, fnmr = farfrr(neg, pos, threshold)
    # FAR and FRR are corrected by the failure-to-acquire rate.
    far = fmr * (1 - fta)
    frr = fta + fnmr * (1 - fta)
    hter = (far + frr) / 2.0
    n_impostors = neg.shape[0]
    n_clients = pos.shape[0]
    false_accepts = int(round(fmr * n_impostors))
    false_rejects = int(round(fnmr * n_clients))
    return (fta, fmr, fnmr, hter, far, frr,
            false_accepts, n_impostors, false_rejects, n_clients)
def _strings(self, fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc):
fta_str = "%.1f%%" % (100 * fta)
fmr_str = "%.1f%% (%d/%d)" % (100 * fmr, fm, ni)
fnmr_str = "%.1f%% (%d/%d)" % (100 * fnmr, fnm, nc)
far_str = "%.1f%%" % (100 * far)
frr_str = "%.1f%%" % (100 * frr)
hter_str = "%.1f%%" % (100 * hter)
return fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str
def compute(self, idx, input_scores, input_names):
''' Compute metrics thresholds and tables (FAR, FMR, FNMR, HTER) for
given system inputs'''
......@@ -204,6 +227,7 @@ class Metrics(MeasureBase):
threshold = self.get_thres(self._criterion, dev_neg, dev_pos, self._far) \
if self._thres is None else self._thres[idx]
title = self._legends[idx] if self._legends is not None else None
if self._thres is None:
far_str = ''
......@@ -219,63 +243,33 @@ class Metrics(MeasureBase):
"Development set `%s`: %e"
% (dev_file or title, threshold), file=self.log_file)
from .. import farfrr
dev_fmr, dev_fnmr = farfrr(dev_neg, dev_pos, threshold)
dev_far = dev_fmr * (1 - dev_fta)
dev_frr = dev_fta + dev_fnmr * (1 - dev_fta)
dev_hter = (dev_far + dev_frr) / 2.0
dev_ni = dev_neg.shape[0] # number of impostors
dev_fm = int(round(dev_fmr * dev_ni)) # number of false accepts
dev_nc = dev_pos.shape[0] # number of clients
dev_fnm = int(round(dev_fnmr * dev_nc)) # number of false rejects
dev_fta_str = "%.1f%%" % (100 * dev_fta)
dev_fmr_str = "%.1f%% (%d/%d)" % (100 * dev_fmr, dev_fm, dev_ni)
dev_fnmr_str = "%.1f%% (%d/%d)" % (100 * dev_fnmr, dev_fnm, dev_nc)
dev_far_str = "%.1f%%" % (100 * dev_far)
dev_frr_str = "%.1f%%" % (100 * dev_frr)
dev_hter_str = "%.1f%%" % (100 * dev_hter)
fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str = \
self._strings(*self._numbers(
dev_neg, dev_pos, threshold, dev_fta))
headers = ['' or title, 'Development %s' % dev_file]
raws = [[self.names[0], dev_fta_str],
[self.names[1], dev_fmr_str],
[self.names[2], dev_fnmr_str],
[self.names[3], dev_far_str],
[self.names[4], dev_frr_str],
[self.names[5], dev_hter_str]]
rows = [[self.names[0], fta_str],
[self.names[1], fmr_str],
[self.names[2], fnmr_str],
[self.names[3], far_str],
[self.names[4], frr_str],
[self.names[5], hter_str]]
if self._eval:
# computes statistics for the eval set based on the threshold a priori
eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, threshold)
eval_far = eval_fmr * (1 - eval_fta)
eval_frr = eval_fta + eval_fnmr * (1 - eval_fta)
eval_hter = (eval_far + eval_frr) / 2.0
eval_ni = eval_neg.shape[0] # number of impostors
eval_fm = int(round(eval_fmr * eval_ni)) # number of false accepts
eval_nc = eval_pos.shape[0] # number of clients
# number of false rejects
eval_fnm = int(round(eval_fnmr * eval_nc))
eval_fta_str = "%.1f%%" % (100 * eval_fta)
eval_fmr_str = "%.1f%% (%d/%d)" % (100 *
eval_fmr, eval_fm, eval_ni)
eval_fnmr_str = "%.1f%% (%d/%d)" % (100 *
eval_fnmr, eval_fnm, eval_nc)
eval_far_str = "%.1f%%" % (100 * eval_far)
eval_frr_str = "%.1f%%" % (100 * eval_frr)
eval_hter_str = "%.1f%%" % (100 * eval_hter)
# computes statistics for the eval set based on the threshold a
# priori
fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str = \
self._strings(*self._numbers(
eval_neg, eval_pos, threshold, eval_fta))
headers.append('Eval. % s' % eval_file)
raws[0].append(eval_fta_str)
raws[1].append(eval_fmr_str)
raws[2].append(eval_fnmr_str)
raws[3].append(eval_far_str)
raws[4].append(eval_frr_str)
raws[5].append(eval_hter_str)
rows[0].append(fta_str)
rows[1].append(fmr_str)
rows[2].append(fnmr_str)
rows[3].append(far_str)
rows[4].append(frr_str)
rows[5].append(hter_str)
click.echo(tabulate(raws, headers, self._tablefmt), file=self.log_file)
click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)
def end_process(self):
''' Close log file if needed'''
......@@ -283,6 +277,91 @@ class Metrics(MeasureBase):
self.log_file.close()
class MultiMetrics(Metrics):
    '''Computes average of metrics based on several protocols (cross
    validation).

    Attributes
    ----------
    log_file : str
        output stream
    names : tuple
        List of names for the metrics.
    '''

    def __init__(self, ctx, scores, evaluation, func_load,
                 names=('NaNs Rate', 'False Positive Rate',
                        'False Negative Rate', 'False Accept Rate',
                        'False Reject Rate', 'Half Total Error Rate')):
        super(MultiMetrics, self).__init__(
            ctx, scores, evaluation, func_load, names=names)

        # One table row per system; with --eval the dev HTER is squeezed
        # in as a second column right after the method name.
        self.headers = ['Methods'] + list(self.names)
        if self._eval:
            self.headers.insert(1, self.names[5] + ' (dev)')
        self.rows = []

    def _strings(self, metrics):
        '''Format mean (std) percentage strings over the protocol axis.

        ``metrics`` is a 2D array whose rows are ``_numbers()`` tuples;
        only the six rate columns are reported, the raw counts are
        dropped.
        '''
        means = metrics.mean(axis=0)
        stds = metrics.std(axis=0)

        def fmt(column):
            return "%.1f%% (%.1f%%)" % (100 * means[column],
                                        100 * stds[column])

        # _numbers() column order: fta, fmr, fnmr, hter, far, frr, ...
        return fmt(0), fmt(1), fmt(2), fmt(4), fmt(5), fmt(3)

    def compute(self, idx, input_scores, input_names):
        '''Computes the average of metrics over several protocols.'''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        # With --eval, dev and eval files alternate: dev at even indices.
        step = 2 if self._eval else 1

        # Development statistics: one threshold and one row of raw
        # numbers per protocol.
        self._thresholds = []
        dev_rows = []
        for i in range(0, len(input_scores), step):
            neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
            threshold = self.get_thres(self._criterion, neg, pos, self._far) \
                if self._thres is None else self._thres[idx]
            self._thresholds.append(threshold)
            dev_rows.append(self._numbers(neg, pos, threshold, fta))
        self._dev_metrics = numpy.array(dev_rows)

        if self._eval:
            # Eval statistics reuse the a-priori thresholds computed on
            # the matching dev protocol.
            eval_rows = []
            for i in range(1, len(input_scores), step):
                neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
                eval_rows.append(
                    self._numbers(neg, pos, self._thresholds[i // 2], fta))
            self._eval_metrics = numpy.array(eval_rows)

        title = self._legends[idx] if self._legends is not None else None
        fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str = \
            self._strings(self._dev_metrics)

        if self._eval:
            # Only the dev HTER is kept next to the full eval columns.
            row = [title, hter_str]
            fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str = \
                self._strings(self._eval_metrics)
            row.extend([fta_str, fmr_str, fnmr_str,
                        far_str, frr_str, hter_str])
        else:
            row = [title, fta_str, fmr_str, fnmr_str,
                   far_str, frr_str, hter_str]
        self.rows.append(row)

    def end_process(self):
        '''Emit the accumulated table, then let the base class close the
        log file.'''
        click.echo(tabulate(self.rows, self.headers,
                            self._tablefmt), file=self.log_file)
        super(MultiMetrics, self).end_process()
class PlotBase(MeasureBase):
''' Base class for plots. Regroup several options and code
shared by the different plots
......@@ -586,7 +665,8 @@ class Hist(PlotBase):
# do not display dev histo
self._hide_dev = ctx.meta.get('hide_dev', False)
if self._hide_dev and not self._eval:
raise click.BadParameter("You can only use --hide-dev along with --eval")
raise click.BadParameter(
"You can only use --hide-dev along with --eval")
# dev hist are displayed next to eval hist
self._ncols *= 1 if self._hide_dev or not self._eval else 2
......@@ -601,7 +681,7 @@ class Hist(PlotBase):
if self._legends is not None and len(self._legends) == self.n_systems \
and not self._hide_dev:
# use same legend for dev and eval if needed
self._legends = [x for pair in zip(self._legends,self._legends)
self._legends = [x for pair in zip(self._legends, self._legends)
for x in pair]
def compute(self, idx, input_scores, input_names):
......
......@@ -74,6 +74,7 @@ setup(
'bob.measure.cli': [
'evaluate = bob.measure.script.commands:evaluate',
'metrics = bob.measure.script.commands:metrics',
'multi-metrics = bob.measure.script.commands:multi_metrics',
'roc = bob.measure.script.commands:roc',
'det = bob.measure.script.commands:det',
'epc = bob.measure.script.commands:epc',
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment