Commit bcccc06b authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Add a command for multi protocol (N-fold cross validation) analysis

parent 0ef7907e
Pipeline #21320 passed with stage
in 25 minutes and 49 seconds
...@@ -11,10 +11,12 @@ SCORE_FORMAT = ( ...@@ -11,10 +11,12 @@ SCORE_FORMAT = (
CRITERIA = ('eer', 'min-hter', 'far') CRITERIA = ('eer', 'min-hter', 'far')
@common_options.metrics_command(common_options.METRICS_HELP.format( @common_options.metrics_command(
common_options.METRICS_HELP.format(
names='FtA, FAR, FRR, FMR, FMNR, HTER', names='FtA, FAR, FRR, FMR, FMNR, HTER',
criteria=CRITERIA, score_format=SCORE_FORMAT, criteria=CRITERIA, score_format=SCORE_FORMAT,
command='bob measure metrics'), criteria=CRITERIA) command='bob measure metrics'),
criteria=CRITERIA)
def metrics(ctx, scores, evaluation, **kwargs): def metrics(ctx, scores, evaluation, **kwargs):
process = figure.Metrics(ctx, scores, evaluation, load.split) process = figure.Metrics(ctx, scores, evaluation, load.split)
process.run() process.run()
...@@ -59,3 +61,15 @@ def hist(ctx, scores, evaluation, **kwargs): ...@@ -59,3 +61,15 @@ def hist(ctx, scores, evaluation, **kwargs):
def evaluate(ctx, scores, evaluation, **kwargs): def evaluate(ctx, scores, evaluation, **kwargs):
common_options.evaluate_flow( common_options.evaluate_flow(
ctx, scores, evaluation, metrics, roc, det, epc, hist, **kwargs) ctx, scores, evaluation, metrics, roc, det, epc, hist, **kwargs)
@common_options.multi_metrics_command(
    common_options.MULTI_METRICS_HELP.format(
        names='FtA, FAR, FRR, FMR, FMNR, HTER',
        criteria=CRITERIA, score_format=SCORE_FORMAT,
        command='bob measure multi-metrics'),
    criteria=CRITERIA)
def multi_metrics(ctx, scores, evaluation, protocols_number, **kwargs):
    # Each protocol contributes one dev score file, plus one eval file
    # when --eval is set; the figure needs that many files per system.
    files_per_protocol = 2 if evaluation else 1
    ctx.meta['min_arg'] = protocols_number * files_per_protocol
    figure.MultiMetrics(ctx, scores, evaluation, load.split).run()
...@@ -21,7 +21,7 @@ def scores_argument(min_arg=1, force_eval=False, **kwargs): ...@@ -21,7 +21,7 @@ def scores_argument(min_arg=1, force_eval=False, **kwargs):
---------- ----------
min_arg : int min_arg : int
the minimum number of file needed to evaluate a system. For example, the minimum number of file needed to evaluate a system. For example,
PAD functionalities needs licit abd spoof and therefore min_arg = 2 vulnerability analysis needs licit and spoof and therefore min_arg = 2
Returns Returns
------- -------
...@@ -920,3 +920,67 @@ def evaluate_flow(ctx, scores, evaluation, metrics, roc, det, epc, hist, ...@@ -920,3 +920,67 @@ def evaluate_flow(ctx, scores, evaluation, metrics, roc, det, epc, hist,
ctx.forward(hist) ctx.forward(hist)
click.echo("Evaluate successfully completed!") click.echo("Evaluate successfully completed!")
click.echo("[plots] => %s" % (ctx.meta['output'])) click.echo("[plots] => %s" % (ctx.meta['output']))
def n_protocols_option(required=True, **kwargs):
    '''Get option for number of protocols.

    Parameters
    ----------
    required : bool
        Whether the option must be given on the command line.
    **kwargs
        Forwarded to ``click.option``.

    Returns
    -------
    callable
        A decorator that adds the ``-pn/--protocols-number`` option to a
        click command and stores its value in ``ctx.meta['protocols_number']``.
    '''
    def custom_n_protocols_option(func):
        def callback(ctx, param, value):
            # With required=False the option may be absent and value is
            # None; abs(None) would raise TypeError, so guard first.
            if value is not None:
                # Negative protocol counts make no sense; use magnitude.
                value = abs(value)
            ctx.meta['protocols_number'] = value
            return value
        return click.option(
            '-pn', '--protocols-number', type=click.INT,
            show_default=True, required=required,
            help='The number of protocols of cross validation.',
            callback=callback, **kwargs)(func)
    return custom_n_protocols_option
def multi_metrics_command(docstring, criteria=('eer', 'min-hter', 'far')):
    '''Return a decorator that wraps ``func`` as the ``multi-metrics``
    click command, attaching the standard measure options and the given
    help ``docstring``.'''
    def decorator(func):
        func.__doc__ = docstring

        @click.command('multi-metrics')
        @scores_argument(nargs=-1)
        @eval_option()
        @n_protocols_option()
        @table_option()
        @output_log_metric_option()
        @criterion_option(criteria)
        @thresholds_option()
        @far_option()
        @legends_option()
        @open_file_mode_option()
        @verbosity_option()
        @click.pass_context
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper
    return decorator
# Help text for the multi-metrics command; placeholders are filled by
# multi_metrics_command via str.format (doubled braces stay literal).
MULTI_METRICS_HELP = """Multi protocol (cross-validation) metrics.

Prints a table that contains {names} for a given threshold criterion
({criteria}). The metrics are averaged over several protocols. The idea is
that each protocol corresponds to one fold in your cross-validation.

You need to provide as many development score files as the number of
protocols per system. You can also provide evaluation files along with dev
files. If evaluation scores are provided, you must use flag `--eval`. The
number of protocols must be provided using the `--protocols-number` option.

{score_format}

Resulting table format can be changed using the `--tablefmt`.

Examples:
    $ {command} -v {{p1,p2,p3}}/scores-dev

    $ {command} -v -e {{p1,p2,p3}}/scores-{{dev,eval}}

    $ {command} -v -e {{sys1,sys2}}/{{p1,p2,p3}}/scores-{{dev,eval}}
"""
...@@ -5,6 +5,7 @@ from abc import ABCMeta, abstractmethod ...@@ -5,6 +5,7 @@ from abc import ABCMeta, abstractmethod
import math import math
import sys import sys
import os.path import os.path
import numpy
import click import click
import matplotlib import matplotlib
import matplotlib.pyplot as mpl import matplotlib.pyplot as mpl
...@@ -128,7 +129,6 @@ class MeasureBase(object): ...@@ -128,7 +129,6 @@ class MeasureBase(object):
# and if only dev: # and if only dev:
# [ (dev_licit_neg, dev_licit_pos), (dev_spoof_neg, dev_licit_pos)] # [ (dev_licit_neg, dev_licit_pos), (dev_spoof_neg, dev_licit_pos)]
# Things to do after the main iterative computations are done # Things to do after the main iterative computations are done
@abstractmethod @abstractmethod
def end_process(self): def end_process(self):
...@@ -192,6 +192,29 @@ class Metrics(MeasureBase): ...@@ -192,6 +192,29 @@ class Metrics(MeasureBase):
def get_thres(self, criterion, dev_neg, dev_pos, far): def get_thres(self, criterion, dev_neg, dev_pos, far):
return utils.get_thres(criterion, dev_neg, dev_pos, far) return utils.get_thres(criterion, dev_neg, dev_pos, far)
def _numbers(self, neg, pos, threshold, fta):
    """Compute raw error metrics for one score set at ``threshold``.

    Returns the tuple ``(fta, fmr, fnmr, hter, far, frr, fm, ni, fnm,
    nc)`` where the last four entries are the false-match /
    false-non-match counts and the impostor / client totals.
    """
    from .. import farfrr
    fmr, fnmr = farfrr(neg, pos, threshold)
    # Fold the failure-to-acquire rate into the match error rates.
    valid_fraction = 1 - fta
    far = fmr * valid_fraction
    frr = fta + fnmr * valid_fraction
    hter = (far + frr) / 2.0
    n_impostors = neg.shape[0]
    n_clients = pos.shape[0]
    false_matches = int(round(fmr * n_impostors))
    false_non_matches = int(round(fnmr * n_clients))
    return (fta, fmr, fnmr, hter, far, frr,
            false_matches, n_impostors, false_non_matches, n_clients)
def _strings(self, fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc):
fta_str = "%.1f%%" % (100 * fta)
fmr_str = "%.1f%% (%d/%d)" % (100 * fmr, fm, ni)
fnmr_str = "%.1f%% (%d/%d)" % (100 * fnmr, fnm, nc)
far_str = "%.1f%%" % (100 * far)
frr_str = "%.1f%%" % (100 * frr)
hter_str = "%.1f%%" % (100 * hter)
return fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str
def compute(self, idx, input_scores, input_names): def compute(self, idx, input_scores, input_names):
''' Compute metrics thresholds and tables (FAR, FMR, FNMR, HTER) for ''' Compute metrics thresholds and tables (FAR, FMR, FNMR, HTER) for
given system inputs''' given system inputs'''
...@@ -204,6 +227,7 @@ class Metrics(MeasureBase): ...@@ -204,6 +227,7 @@ class Metrics(MeasureBase):
threshold = self.get_thres(self._criterion, dev_neg, dev_pos, self._far) \ threshold = self.get_thres(self._criterion, dev_neg, dev_pos, self._far) \
if self._thres is None else self._thres[idx] if self._thres is None else self._thres[idx]
title = self._legends[idx] if self._legends is not None else None title = self._legends[idx] if self._legends is not None else None
if self._thres is None: if self._thres is None:
far_str = '' far_str = ''
...@@ -219,63 +243,33 @@ class Metrics(MeasureBase): ...@@ -219,63 +243,33 @@ class Metrics(MeasureBase):
"Development set `%s`: %e" "Development set `%s`: %e"
% (dev_file or title, threshold), file=self.log_file) % (dev_file or title, threshold), file=self.log_file)
from .. import farfrr fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str = \
dev_fmr, dev_fnmr = farfrr(dev_neg, dev_pos, threshold) self._strings(*self._numbers(
dev_far = dev_fmr * (1 - dev_fta) dev_neg, dev_pos, threshold, dev_fta))
dev_frr = dev_fta + dev_fnmr * (1 - dev_fta)
dev_hter = (dev_far + dev_frr) / 2.0
dev_ni = dev_neg.shape[0] # number of impostors
dev_fm = int(round(dev_fmr * dev_ni)) # number of false accepts
dev_nc = dev_pos.shape[0] # number of clients
dev_fnm = int(round(dev_fnmr * dev_nc)) # number of false rejects
dev_fta_str = "%.1f%%" % (100 * dev_fta)
dev_fmr_str = "%.1f%% (%d/%d)" % (100 * dev_fmr, dev_fm, dev_ni)
dev_fnmr_str = "%.1f%% (%d/%d)" % (100 * dev_fnmr, dev_fnm, dev_nc)
dev_far_str = "%.1f%%" % (100 * dev_far)
dev_frr_str = "%.1f%%" % (100 * dev_frr)
dev_hter_str = "%.1f%%" % (100 * dev_hter)
headers = ['' or title, 'Development %s' % dev_file] headers = ['' or title, 'Development %s' % dev_file]
raws = [[self.names[0], dev_fta_str], rows = [[self.names[0], fta_str],
[self.names[1], dev_fmr_str], [self.names[1], fmr_str],
[self.names[2], dev_fnmr_str], [self.names[2], fnmr_str],
[self.names[3], dev_far_str], [self.names[3], far_str],
[self.names[4], dev_frr_str], [self.names[4], frr_str],
[self.names[5], dev_hter_str]] [self.names[5], hter_str]]
if self._eval: if self._eval:
# computes statistics for the eval set based on the threshold a priori # computes statistics for the eval set based on the threshold a
eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, threshold) # priori
eval_far = eval_fmr * (1 - eval_fta) fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str = \
eval_frr = eval_fta + eval_fnmr * (1 - eval_fta) self._strings(*self._numbers(
eval_hter = (eval_far + eval_frr) / 2.0 eval_neg, eval_pos, threshold, eval_fta))
eval_ni = eval_neg.shape[0] # number of impostors
eval_fm = int(round(eval_fmr * eval_ni)) # number of false accepts
eval_nc = eval_pos.shape[0] # number of clients
# number of false rejects
eval_fnm = int(round(eval_fnmr * eval_nc))
eval_fta_str = "%.1f%%" % (100 * eval_fta)
eval_fmr_str = "%.1f%% (%d/%d)" % (100 *
eval_fmr, eval_fm, eval_ni)
eval_fnmr_str = "%.1f%% (%d/%d)" % (100 *
eval_fnmr, eval_fnm, eval_nc)
eval_far_str = "%.1f%%" % (100 * eval_far)
eval_frr_str = "%.1f%%" % (100 * eval_frr)
eval_hter_str = "%.1f%%" % (100 * eval_hter)
headers.append('Eval. % s' % eval_file) headers.append('Eval. % s' % eval_file)
raws[0].append(eval_fta_str) rows[0].append(fta_str)
raws[1].append(eval_fmr_str) rows[1].append(fmr_str)
raws[2].append(eval_fnmr_str) rows[2].append(fnmr_str)
raws[3].append(eval_far_str) rows[3].append(far_str)
raws[4].append(eval_frr_str) rows[4].append(frr_str)
raws[5].append(eval_hter_str) rows[5].append(hter_str)
click.echo(tabulate(raws, headers, self._tablefmt), file=self.log_file) click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)
def end_process(self): def end_process(self):
''' Close log file if needed''' ''' Close log file if needed'''
...@@ -283,6 +277,91 @@ class Metrics(MeasureBase): ...@@ -283,6 +277,91 @@ class Metrics(MeasureBase):
self.log_file.close() self.log_file.close()
class MultiMetrics(Metrics):
    '''Computes average of metrics based on several protocols (cross
    validation)

    Attributes
    ----------
    log_file : str
        output stream
    names : tuple
        List of names for the metrics.
    '''

    def __init__(self, ctx, scores, evaluation, func_load,
                 names=('NaNs Rate', 'False Positive Rate',
                        'False Negative Rate', 'False Accept Rate',
                        'False Reject Rate', 'Half Total Error Rate')):
        super(MultiMetrics, self).__init__(
            ctx, scores, evaluation, func_load, names=names)

        # One table row per system; columns are 'Methods' plus the metric
        # names.  With --eval an extra 'HTER (dev)' column is inserted so
        # the dev HTER and the six eval metrics appear side by side.
        self.headers = ['Methods'] + list(self.names)
        if self._eval:
            self.headers.insert(1, self.names[5] + ' (dev)')
        self.rows = []

    def _strings(self, metrics):
        # NOTE(review): overrides Metrics._strings with a different
        # signature: here `metrics` is a 2D numpy array whose rows are the
        # per-protocol tuples produced by Metrics._numbers, in column
        # order (fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc); the
        # four trailing count columns are ignored for display.
        ftam, fmrm, fnmrm, hterm, farm, frrm, _, _, _, _ = metrics.mean(axis=0)
        ftas, fmrs, fnmrs, hters, fars, frrs, _, _, _, _ = metrics.std(axis=0)
        # Each cell is rendered as "mean% (std%)" across protocols.
        fta_str = "%.1f%% (%.1f%%)" % (100 * ftam, 100 * ftas)
        fmr_str = "%.1f%% (%.1f%%)" % (100 * fmrm, 100 * fmrs)
        fnmr_str = "%.1f%% (%.1f%%)" % (100 * fnmrm, 100 * fnmrs)
        far_str = "%.1f%% (%.1f%%)" % (100 * farm, 100 * fars)
        frr_str = "%.1f%% (%.1f%%)" % (100 * frrm, 100 * frrs)
        hter_str = "%.1f%% (%.1f%%)" % (100 * hterm, 100 * hters)
        return fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str

    def compute(self, idx, input_scores, input_names):
        '''Computes the average of metrics over several protocols.'''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        # With --eval, scores alternate dev/eval: even indices are dev
        # files, odd indices eval files; otherwise every file is dev.
        step = 2 if self._eval else 1
        self._dev_metrics = []
        self._thresholds = []
        for i in range(0, len(input_scores), step):
            neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
            # Threshold is selected per protocol on the dev scores, unless
            # fixed thresholds were supplied (one per system, reused for
            # all protocols of that system).
            threshold = self.get_thres(self._criterion, neg, pos, self._far) \
                if self._thres is None else self._thres[idx]
            self._thresholds.append(threshold)
            self._dev_metrics.append(self._numbers(neg, pos, threshold, fta))
        self._dev_metrics = numpy.array(self._dev_metrics)

        if self._eval:
            self._eval_metrics = []
            for i in range(1, len(input_scores), step):
                neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
                # Reuse the a-priori threshold of the matching dev fold.
                threshold = self._thresholds[i // 2]
                self._eval_metrics.append(
                    self._numbers(neg, pos, threshold, fta))
            self._eval_metrics = numpy.array(self._eval_metrics)

        title = self._legends[idx] if self._legends is not None else None

        fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str = \
            self._strings(self._dev_metrics)

        if self._eval:
            # With eval metrics following, only the dev HTER is kept.
            self.rows.append([title, hter_str])
        else:
            self.rows.append([title, fta_str, fmr_str, fnmr_str,
                              far_str, frr_str, hter_str])

        if self._eval:
            # computes statistics for the eval set based on the threshold a
            # priori
            fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str = \
                self._strings(self._eval_metrics)

            self.rows[-1].extend([fta_str, fmr_str, fnmr_str,
                                  far_str, frr_str, hter_str])

    def end_process(self):
        '''Prints the accumulated table, then defers cleanup to Metrics.'''
        click.echo(tabulate(self.rows, self.headers,
                            self._tablefmt), file=self.log_file)
        super(MultiMetrics, self).end_process()
class PlotBase(MeasureBase): class PlotBase(MeasureBase):
''' Base class for plots. Regroup several options and code ''' Base class for plots. Regroup several options and code
shared by the different plots shared by the different plots
...@@ -586,7 +665,8 @@ class Hist(PlotBase): ...@@ -586,7 +665,8 @@ class Hist(PlotBase):
# do not display dev histo # do not display dev histo
self._hide_dev = ctx.meta.get('hide_dev', False) self._hide_dev = ctx.meta.get('hide_dev', False)
if self._hide_dev and not self._eval: if self._hide_dev and not self._eval:
raise click.BadParameter("You can only use --hide-dev along with --eval") raise click.BadParameter(
"You can only use --hide-dev along with --eval")
# dev hist are displayed next to eval hist # dev hist are displayed next to eval hist
self._ncols *= 1 if self._hide_dev or not self._eval else 2 self._ncols *= 1 if self._hide_dev or not self._eval else 2
...@@ -601,7 +681,7 @@ class Hist(PlotBase): ...@@ -601,7 +681,7 @@ class Hist(PlotBase):
if self._legends is not None and len(self._legends) == self.n_systems \ if self._legends is not None and len(self._legends) == self.n_systems \
and not self._hide_dev: and not self._hide_dev:
# use same legend for dev and eval if needed # use same legend for dev and eval if needed
self._legends = [x for pair in zip(self._legends,self._legends) self._legends = [x for pair in zip(self._legends, self._legends)
for x in pair] for x in pair]
def compute(self, idx, input_scores, input_names): def compute(self, idx, input_scores, input_names):
......
...@@ -74,6 +74,7 @@ setup( ...@@ -74,6 +74,7 @@ setup(
'bob.measure.cli': [ 'bob.measure.cli': [
'evaluate = bob.measure.script.commands:evaluate', 'evaluate = bob.measure.script.commands:evaluate',
'metrics = bob.measure.script.commands:metrics', 'metrics = bob.measure.script.commands:metrics',
'multi-metrics = bob.measure.script.commands:multi_metrics',
'roc = bob.measure.script.commands:roc', 'roc = bob.measure.script.commands:roc',
'det = bob.measure.script.commands:det', 'det = bob.measure.script.commands:det',
'epc = bob.measure.script.commands:epc', 'epc = bob.measure.script.commands:epc',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment