Commit dc0de546 authored by Theophile GENTILHOMME

Change the way the scores arguments are passed to the compute() function: it no longer relies on dev/eval pairs and can take any number of different files (e.g. train).
parent 2b95ebd9
Pipeline #19285 passed in 25 minutes and 28 seconds
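For orientation, a minimal sketch of the contract the new API asks a derived figure to implement. The subclass below and its print-out are invented for illustration; only the compute(idx, input_scores, input_names) signature, MeasureBase and utils.get_fta_list come from this commit:

from bob.measure import utils
from bob.measure.script import figure

class MinimalMetrics(figure.MeasureBase):
    '''Hypothetical example subclass, not part of this commit.'''

    def compute(self, idx, input_scores, input_names):
        # input_scores holds one loaded score set per required file of
        # system `idx` (dev only, or dev + eval when --evaluation is on);
        # input_names holds the matching file base names.
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        print("system %d (%s): %d negatives, %d positives"
              % (idx, input_names[0], len(dev_neg), len(dev_pos)))

    def end_process(self):
        pass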
......@@ -37,7 +37,7 @@ def metrics(ctx, scores, evaluation, **kwargs):
$ bob measure metrics {dev,eval}-scores1 {dev,eval}-scores2
"""
process = figure.Metrics(ctx, scores, evaluation, load.split_files)
process = figure.Metrics(ctx, scores, evaluation, load.split)
process.run()
@click.command()
......@@ -76,7 +76,7 @@ def roc(ctx, scores, evaluation, **kwargs):
$ bob measure roc -o my_roc.pdf dev-scores1 eval-scores1
"""
process = figure.Roc(ctx, scores, evaluation, load.split_files)
process = figure.Roc(ctx, scores, evaluation, load.split)
process.run()
@click.command()
......@@ -114,11 +114,11 @@ def det(ctx, scores, evaluation, **kwargs):
$ bob measure det -o my_det.pdf dev-scores1 eval-scores1
"""
process = figure.Det(ctx, scores, evaluation, load.split_files)
process = figure.Det(ctx, scores, evaluation, load.split)
process.run()
@click.command()
@common_options.scores_argument(eval_mandatory=True, nargs=-1)
@common_options.scores_argument(min_arg=2, nargs=-1)
@common_options.output_plot_file_option(default_out='epc.pdf')
@common_options.title_option()
@common_options.titles_option()
......@@ -144,7 +144,7 @@ def epc(ctx, scores, **kwargs):
$ bob measure epc -o my_epc.pdf dev-scores1 eval-scores1
"""
process = figure.Epc(ctx, scores, True, load.split_files)
process = figure.Epc(ctx, scores, True, load.split)
process.run()
@click.command()
......@@ -184,7 +184,7 @@ def hist(ctx, scores, evaluation, **kwargs):
$ bob measure hist --criter hter --show-dev dev-scores1 eval-scores1
"""
process = figure.Hist(ctx, scores, evaluation, load.split_files)
process = figure.Hist(ctx, scores, evaluation, load.split)
process.run()
@click.command()
......
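Each command above now passes load.split (a per-file loader) instead of load.split_files (which took a list of files) to its figure class, since the figures now load their inputs one file at a time. A hedged sketch of the difference; the file name is a placeholder and load.split is assumed to return a (negatives, positives) pair as elsewhere in bob.measure:

from bob.measure import load

# old style: one call for a whole group of score files
# pairs = load.split_files(['dev-scores1', 'eval-scores1'])

# new style: the figure machinery calls the loader once per file
negatives, positives = load.split('dev-scores1')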
......@@ -9,17 +9,15 @@ from bob.extension.scripts.click_helper import (bool_option, list_float_option)
LOGGER = logging.getLogger(__name__)
def scores_argument(eval_mandatory=False, min_len=1, **kwargs):
def scores_argument(min_arg=1, **kwargs):
"""Get the argument for scores, and add `dev-scores` and `eval-scores` in
the context when the `--evaluation` flag is on (default)
Parameters
----------
eval_mandatory :
If evaluation files are mandatory
min_len :
The min lenght of inputs files that are needed. If eval_mandatory is
True, this quantity is multiplied by 2.
min_arg : int
the minimum number of files needed to evaluate a system. For example,
PAD functionalities need licit and spoof files and therefore min_arg = 2
Returns
-------
......@@ -28,44 +26,27 @@ def scores_argument(eval_mandatory=False, min_len=1, **kwargs):
"""
def custom_scores_argument(func):
def callback(ctx, param, value):
length = len(value)
min_arg = min_len or 1
ctx.meta['min_arg'] = min_arg
if length < min_arg:
min_a = min_arg or 1
multi = 1
error = ''
if 'evaluation' in ctx.meta and ctx.meta['evaluation']:
multi += 1
error += '- %d evaluation file(s) \n' % min_a
if 'train' in ctx.meta and ctx.meta['train']:
multi += 1
error += '- %d training file(s) \n' % min_a
# add more tests here if other inputs are needed
min_a *= multi
ctx.meta['min_arg'] = min_a
if len(value) < 1 or len(value) % ctx.meta['min_arg'] != 0:
raise click.BadParameter(
'You must provide at least %d score files' % min_arg,
ctx=ctx
'The number of provided scores must be > 0 and a multiple of %d '
'because the following files are required:\n'
'- %d development file(s)\n' % (min_a, min_arg or 1) +
error, ctx=ctx
)
else:
ctx.meta['scores'] = value
step = 1
if eval_mandatory or ctx.meta['evaluation']:
step = 2
if (length % (min_arg * 2)) != 0:
pref = 'T' if eval_mandatory else \
('When `--evaluation` flag is on t')
raise click.BadParameter(
'%sest-score(s) must '
'be provided along with dev-score(s). '
'You must provide at least %d score files.' \
% (pref, min_arg * 2), ctx=ctx
)
for arg in range(min_arg):
ctx.meta['dev_scores_%d' % arg] = [
value[i] for i in range(arg * step, length,
min_arg * step)
]
if step > 1:
ctx.meta['eval_scores_%d' % arg] = [
value[i] for i in range((arg * step + 1),
length, min_arg * step)
]
ctx.meta['n_sys'] = len(ctx.meta['dev_scores_0'])
if 'titles' in ctx.meta and \
len(ctx.meta['titles']) != ctx.meta['n_sys']:
raise click.BadParameter(
'#titles not equal to #sytems', ctx=ctx
)
ctx.meta['scores'] = value
return value
return click.argument(
'scores', type=click.Path(exists=True),
......
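The rewritten callback computes the required file count as min_arg times the number of enabled score sets (development always, plus evaluation and/or training when those flags are on) and rejects anything that is not a positive multiple of it. A standalone sketch of the same rule, with made-up helper names:

def required_multiple(min_arg, evaluation=True, train=False):
    # one development set is always needed; evaluation and training
    # each add one more set, so PAD (min_arg=2) with --evaluation needs
    # multiples of 4 files: licit dev, spoof dev, licit eval, spoof eval
    sets = 1 + int(evaluation) + int(train)
    return min_arg * sets

def accepts(n_files, min_arg, evaluation=True, train=False):
    multiple = required_multiple(min_arg, evaluation, train)
    return n_files > 0 and n_files % multiple == 0

assert accepts(4, min_arg=2, evaluation=True)       # PAD dev+eval: OK
assert not accepts(3, min_arg=1, evaluation=True)   # odd count with eval: rejected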
......@@ -56,14 +56,17 @@ class MeasureBase(object):
self._min_arg = 1 if 'min_arg' not in ctx.meta else ctx.meta['min_arg']
self._ctx = ctx
self.func_load = func_load
self.dev_names, self.eval_names, self.dev_scores, self.eval_scores = \
self._load_files()
self.n_sytem = len(self.dev_names[0]) # at least one set of dev scores
self._titles = None if 'titles' not in ctx.meta else ctx.meta['titles']
if self._titles is not None and len(self._titles) != self.n_sytem:
self._eval = evaluation
self._min_arg = 1 if 'min_arg' not in ctx.meta else ctx.meta['min_arg']
if len(scores) < 1 or len(scores) % self._min_arg != 0:
raise click.BadParameter(
'Number of arguments must be a non-zero multiple of %d' % self._min_arg
)
self.n_systems = int(len(scores) / self._min_arg)
if self._titles is not None and len(self._titles) != self.n_systems:
raise click.BadParameter("Number of titles must be equal to the "
"number of systems")
self._eval = evaluation
def run(self):
""" Generate outputs (e.g. metrics, files, pdf plots).
......@@ -80,40 +83,24 @@ class MeasureBase(object):
#with the dev (and eval) scores of each system
# Note that more than one dev or eval scores score can be passed to
# each system
for idx in range(self.n_sytem):
dev_score = []
eval_score = []
dev_file = []
eval_file = []
for arg in range(self._min_arg):
dev_score.append(self.dev_scores[arg][idx])
dev_file.append(self.dev_names[arg][idx])
eval_score.append(self.eval_scores[arg][idx] \
if self.eval_scores[arg] is not None else None)
eval_file.append(self.eval_names[arg][idx] \
if self.eval_names[arg] is not None else None)
if self._min_arg == 1: # most of measure only take one arg
# so do not pass a list of one arg
#does the main computations/plottings here
self.compute(idx, dev_score[0], dev_file[0], eval_score[0],
eval_file[0])
else:
#does the main computations/plottings here
self.compute(idx, dev_score, dev_file, eval_score, eval_file)
for idx in range(self.n_systems):
input_scores, input_names = self._load_files(
self._scores[idx:(idx + self._min_arg)]
)
self.compute(idx, input_scores, input_names)
#setup final configuration, plotting properties, ...
self.end_process()
#protected functions that need to be overwritten
def init_process(self):
""" Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
before iterating through the different sytems.
before iterating through the different systems.
Should be reimplemented in derived classes"""
pass
#Main computations are done here in the subclasses
@abstractmethod
def compute(self, idx, dev_score, dev_file=None,
eval_score=None, eval_file=None):
def compute(self, idx, input_scores, input_names):
"""Compute metrics or plots from the given scores provided by
:py:func:`~bob.measure.script.figure.MeasureBase.run`.
Should be reimplemented in derived classes
......@@ -122,20 +109,10 @@ class MeasureBase(object):
----------
idx : :obj:`int`
index of the system
dev_score:
Development scores. Can be a tuple (neg, pos) of
:py:class:`numpy.ndarray` (e.g.
:py:func:`~bob.measure.script.figure.Roc.compute`) or
a :any:`list` of tuples of :py:class:`numpy.ndarray` (e.g. cmc)
dev_file : str
name of the dev file without extension
eval_score:
eval scores. Can be a tuple (neg, pos) of
:py:class:`numpy.ndarray` (e.g.
:py:func:`~bob.measure.script.figure.Roc.compute`) or
a :any:`list` of tuples of :py:class:`numpy.ndarray` (e.g. cmc)
eval_file : str
name of the eval file without extension
input_scores: :any:`list`
list of scores returned by the loading function
input_names: :any:`list`
list of base names for the input files of the system
"""
pass
......@@ -143,65 +120,29 @@ class MeasureBase(object):
@abstractmethod
def end_process(self):
""" Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
after iterating through the different sytems.
after iterating through the different systems.
Should be reimplemented in derived classes"""
pass
#common protected functions
def _load_files(self):
''' Load the input files and returns
def _load_files(self, filepaths):
''' Load the input files and return the loaded scores and the base names of the files
Returns
-------
dev_scores: :any:`list`: A list that contains, for each required
dev score file, the output of ``func_load``
eval_scores: :any:`list`: A list that contains, for each required
eval score file, the output of ``func_load``
scores: :any:`list`:
A list that contains the output of
``func_load`` for the given files
basenames: :any:`list`:
A list of basenames for the given files
'''
def _extract_file_names(filenames):
if filenames is None:
return None
res = []
for file_path in filenames:
name = os.path.basename(file_path)
res.append(name.split(".")[0])
return res
dev_scores = []
eval_scores = []
dev_files = []
eval_files = []
for arg in range(self._min_arg):
key = 'dev_scores_%d' % arg
dev_paths = self._scores if key not in self._ctx.meta else \
self._ctx.meta[key]
key = 'eval_scores_%d' % arg
eval_paths = None if key not in self._ctx.meta else \
self._ctx.meta[key]
dev_files.append(_extract_file_names(dev_paths))
eval_files.append(_extract_file_names(eval_paths))
dev_scores.append(self.func_load(dev_paths))
eval_scores.append(self.func_load(eval_paths))
return (dev_files, eval_files, dev_scores, eval_scores)
def _process_scores(self, dev_score, eval_score):
'''Process score files and return neg/pos/fta for eval and dev'''
dev_neg = dev_pos = dev_fta = eval_neg = eval_pos = eval_fta = None
if dev_score[0] is not None:
(dev_neg, dev_pos), dev_fta = utils.get_fta(dev_score)
if dev_neg is None:
raise click.UsageError("While loading dev-score file")
if self._eval and eval_score is not None and eval_score[0] is not None:
eval_score, eval_fta = utils.get_fta(eval_score)
eval_neg, eval_pos = eval_score
if eval_neg is None:
raise click.UsageError("While loading eval-score file")
return (dev_neg, dev_pos, dev_fta, eval_neg, eval_pos, eval_fta)
scores = []
basenames = []
for filename in filepaths:
basenames.append(os.path.basename(filename).split(".")[0])
scores.append(self.func_load(filename))
return scores, basenames
class Metrics(MeasureBase):
''' Compute metrics from score files
......@@ -234,12 +175,16 @@ class Metrics(MeasureBase):
if self._log is not None:
self.log_file = open(self._log, self._open_mode)
def compute(self, idx, dev_score, dev_file=None,
eval_score=None, eval_file=None):
def compute(self, idx, input_scores, input_names):
''' Compute metrics, thresholds and tables (FAR, FMR, FNMR, HTER) for
the given system inputs'''
dev_neg, dev_pos, dev_fta, eval_neg, eval_pos, eval_fta =\
self._process_scores(dev_score, eval_score)
neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
dev_file = input_names[0]
if self._eval:
eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]
eval_file = input_names[1]
threshold = utils.get_thres(self._criter, dev_neg, dev_pos, self._far) \
if self._thres is None else self._thres[idx]
title = self._titles[idx] if self._titles is not None else None
......@@ -281,7 +226,7 @@ class Metrics(MeasureBase):
['FRR', dev_frr_str],
['HTER', dev_hter_str]]
if self._eval and eval_neg is not None:
if self._eval:
# computes statistics for the eval set based on the threshold a priori
eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, threshold)
eval_far = eval_fmr * (1 - eval_fta)
......@@ -341,8 +286,7 @@ class PlotBase(MeasureBase):
if 'style' in ctx.meta:
mpl.style.use(ctx.meta['style'])
self._nb_figs = 2 if self._eval and self._split else 1
self._multi_plots = len(self.dev_scores) > 1
self._colors = utils.get_colors(len(self.dev_scores))
self._colors = utils.get_colors(self.n_systems)
self._states = ['Development', 'Evaluation']
self._title = None if 'title' not in ctx.meta else ctx.meta['title']
self._x_label = None if 'x_label' not in ctx.meta else\
......@@ -420,7 +364,7 @@ class PlotBase(MeasureBase):
def _label(self, base, name, idx):
if self._titles is not None and len(self._titles) > idx:
return self._titles[idx]
if self._multi_plots:
if self.n_systems > 1:
return base + (" %d (%s)" % (idx + 1, name))
return base + (" (%s)" % name)
......@@ -439,12 +383,16 @@ class Roc(PlotBase):
if self._axlim is None:
self._axlim = [1e-4, 1.0, 1e-4, 1.0]
def compute(self, idx, dev_score, dev_file=None,
eval_score=None, eval_file=None):
def compute(self, idx, input_scores, input_names):
''' Plot ROC for dev and eval data using
:py:func:`bob.measure.plot.roc`'''
dev_neg, dev_pos, _, eval_neg, eval_pos, _ =\
self._process_scores(dev_score, eval_score)
neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
dev_neg, dev_pos, _ = neg_list[0], pos_list[0], fta_list[0]
dev_file = input_names[0]
if self._eval:
eval_neg, eval_pos, _ = neg_list[1], pos_list[1], fta_list[1]
eval_file = input_names[1]
mpl.figure(1)
if self._eval:
linestyle = '-' if not self._split else LINESTYLES[idx % 14]
......@@ -491,12 +439,16 @@ class Det(PlotBase):
if self._x_rotation is None:
self._x_rotation = 50
def compute(self, idx, dev_score, dev_file=None,
eval_score=None, eval_file=None):
def compute(self, idx, input_scores, input_names):
''' Plot DET for dev and eval data using
:py:func:`bob.measure.plot.det`'''
dev_neg, dev_pos, _, eval_neg, eval_pos, _ =\
self._process_scores(dev_score, eval_score)
neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
dev_neg, dev_pos, _ = neg_list[0], pos_list[0], fta_list[0]
dev_file = input_names[0]
if self._eval:
eval_neg, eval_pos, _ = neg_list[1], pos_list[1], fta_list[1]
eval_file = input_names[1]
mpl.figure(1)
if self._eval and eval_neg is not None:
linestyle = '-' if not self._split else LINESTYLES[idx % 14]
......@@ -538,7 +490,7 @@ class Epc(PlotBase):
''' Handles the plotting of EPC '''
def __init__(self, ctx, scores, evaluation, func_load):
super(Epc, self).__init__(ctx, scores, evaluation, func_load)
if 'eval_scores_0' not in self._ctx.meta:
if self._min_arg != 2:
raise click.UsageError("EPC requires dev and eval score files")
self._title = self._title or 'EPC'
self._x_label = self._x_label or r'$\alpha$'
......@@ -548,10 +500,15 @@ class Epc(PlotBase):
self._nb_figs = 1
self._far_at = None
def compute(self, idx, dev_score, dev_file, eval_score, eval_file=None):
def compute(self, idx, input_scores, input_names):
''' Plot EPC using :py:func:`bob.measure.plot.epc` '''
dev_neg, dev_pos, _, eval_neg, eval_pos, _ =\
self._process_scores(dev_score, eval_score)
neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
dev_neg, dev_pos, _ = neg_list[0], pos_list[0], fta_list[0]
dev_file = input_names[0]
if self._eval:
eval_neg, eval_pos, _ = neg_list[1], pos_list[1], fta_list[1]
eval_file = input_names[1]
plot.epc(
dev_neg, dev_pos, eval_neg, eval_pos, self._points,
color=self._colors[idx], linestyle=LINESTYLES[idx % 14],
......@@ -583,11 +540,12 @@ class Hist(PlotBase):
self._title_base = self._title or 'Scores'
self._end_setup_plot = False
def compute(self, idx, dev_score, dev_file=None,
eval_score=None, eval_file=None):
def compute(self, idx, input_scores, input_names):
''' Draw histograms of negative and positive scores.'''
dev_neg, dev_pos, eval_neg, eval_pos, threshold = \
self._get_neg_pos_thres(idx, dev_score, eval_score)
self._get_neg_pos_thres(idx, input_scores, input_names)
dev_file = input_names[0]
eval_file = None if len(input_names) != 2 else input_names[1]
fig = mpl.figure()
if eval_neg is not None and self._show_dev:
......@@ -648,10 +606,12 @@ class Hist(PlotBase):
mpl.legend(lines, labels,
loc='best', fancybox=True, framealpha=0.5)
def _get_neg_pos_thres(self, idx, dev_score, eval_score):
dev_neg, dev_pos, _, eval_neg, eval_pos, _ = self._process_scores(
dev_score, eval_score
)
def _get_neg_pos_thres(self, idx, input_scores, input_names):
neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
dev_neg, dev_pos, _ = neg_list[0], pos_list[0], fta_list[0]
eval_neg = eval_pos = None
if self._eval:
eval_neg, eval_pos, _ = neg_list[1], pos_list[1], fta_list[1]
threshold = utils.get_thres(
self._criter, dev_neg,
dev_pos
......
......@@ -53,6 +53,37 @@ def get_fta(scores):
fta_total += total
return ((neg, pos), fta_sum / fta_total)
def get_fta_list(scores):
""" Get FTAs for a list of scores
Parameters
----------
scores: :any:`list`
list of scores
Returns
-------
neg_list: :any:`list`
list of negatives
pos_list: :any:`list`
list of positives
fta_list: :any:`list`
list of FTAs
"""
neg_list = []
pos_list = []
fta_list = []
for score in scores:
neg = pos = fta = None
if score is not None:
(neg, pos), fta = get_fta(score)
if neg is None:
raise ValueError("While loading dev-score file")
neg_list.append(neg)
pos_list.append(pos)
fta_list.append(fta)
return (neg_list, pos_list, fta_list)
def get_thres(criter, neg, pos, far=None):
"""Get threshold for the given positive/negatives scores and criterion
......
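A brief usage sketch for the new utils.get_fta_list helper; the arrays are made-up (negatives, positives) pairs of the kind a loader such as load.split produces:

import numpy
from bob.measure import utils

dev = (numpy.array([-1.0, -0.5, -0.2]), numpy.array([0.3, 0.8, 1.2]))
eval_ = (numpy.array([-0.9, -0.4]), numpy.array([0.5, 1.1]))
neg_list, pos_list, fta_list = utils.get_fta_list([dev, eval_])
# index 0 -> development, index 1 -> evaluation; each fta value is the
# failure-to-acquire rate computed from NaN entries in that pair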
......@@ -114,6 +114,7 @@ Utilities
.. autosummary::
bob.measure.utils.remove_nan
bob.measure.utils.get_fta
bob.measure.utils.get_fta_list
bob.measure.utils.get_thres
bob.measure.utils.get_colors
......