Commit 44c25df8 authored by Theophile GENTILHOMME's avatar Theophile GENTILHOMME

Add and concanate click options, add threshold criterion

parent f6e324a7
Pipeline #18454 passed with stage
in 58 minutes and 37 seconds
......@@ -15,7 +15,8 @@ from bob.extension.scripts.click_helper import verbosity_option
@common_options.open_file_mode_option()
@common_options.output_plot_metric_option()
@common_options.criterion_option()
@common_options.threshold_option()
@common_options.thresholds_option()
@common_options.far_option()
@verbosity_option()
@click.pass_context
def metrics(ctx, scores, test, **kwargs):
......@@ -143,7 +144,7 @@ def epc(ctx, scores, **kwargs):
@common_options.n_bins_option()
@common_options.criterion_option()
@common_options.axis_fontsize_option()
@common_options.threshold_option()
@common_options.thresholds_option()
@verbosity_option()
@click.pass_context
def hist(ctx, scores, test, **kwargs):
......
......@@ -87,40 +87,60 @@ def semilogx_option(dflt= False, **kwargs):
callback=callback, **kwargs)(func)
return custom_semilogx_option
def axes_val_option(dflt=None, **kwargs):
'''Get option for min/max values for axes. If one the default is None, no
default is used
def list_float_option(name, short_name, desc, nitems=None, dflt=None, **kwargs):
'''Get option to get a list of float f
Parameters
----------
name: str
name of the option
short_name: str
short name for the option
desc: str
short description for the option
nitems: :obj:`int`
if given, the parsed list must contains this number of items
dflt: :any:`list`
List of default min/max values for axes. Must be of length 4
List of default values for axes.
'''
def custom_axes_val_option(func):
def custom_list_float_option(func):
def callback(ctx, param, value):
if value is not None:
tmp = value.split(',')
if len(tmp) != 4:
raise click.BadParameter('Must provide 4 axis limits')
if nitems is not None and len(tmp) != nitems:
raise click.BadParameter(
'%s Must provide %d axis limits' % (name, nitems)
)
try:
value = [float(i) for i in tmp]
except:
raise click.BadParameter('Axis limits must be floats')
raise click.BadParameter('Inputs of %s be floats' % name)
if None in value:
value = None
elif None not in dflt and len(dflt) == 4:
elif dflt is not None and None not in dflt and len(dflt) == nitems:
value = dflt if not all(
isinstance(x, float) for x in dflt
) else None
ctx.meta['axlim'] = value
ctx.meta[name.replace('-', '_')] = value
return value
return click.option(
'-L', '--axlim', default=None, show_default=True,
help='min/max axes values separated by commas (min_x, max_x, '
'min_y, max_y)',
callback=callback, **kwargs)(func)
return custom_axes_val_option
'-'+short_name, '--'+name, default=None, show_default=True,
help=desc, callback=callback, **kwargs)(func)
return custom_list_float_option
def axes_val_option(dflt=None, **kwargs):
return list_float_option(
name='axlim', short_name='L',
desc='min/max axes values separated by commas (min_x, max_x, min_y, max_y)',
nitems=4, dflt=dflt, **kwargs
)
def thresholds_option(**kwargs):
return list_float_option(
name='thres', short_name='T',
desc='Given threshold for metrics computations',
nitems=None, dflt=None, **kwargs
)
def axis_fontsize_option(dflt=8, **kwargs):
'''Get option for axis font size'''
......@@ -162,6 +182,20 @@ def fmr_line_at_option(**kwargs):
callback=callback, **kwargs)(func)
return custom_fmr_line_at_option
def cost_option(**kwargs):
'''Get option to get cost for FAR'''
def custom_cost_option(func):
def callback(ctx, param, value):
if value < 0 or value > 1:
raise click.BadParameter("Cost for FAR must be betwen 0 and 1")
ctx.meta['cost'] = value
return value
return click.option(
'-C', '--cost', type=float, default=0.99, show_default=True,
help='Cost for FAR in minDCF',
callback=callback, **kwargs)(func)
return custom_cost_option
def n_sys_option(**kwargs):
'''Get the number of systems to be processed'''
def custom_n_sys_option(func):
......@@ -282,11 +316,19 @@ def open_file_mode_option(**kwargs):
callback=callback, **kwargs)(func)
return custom_open_file_mode_option
def criterion_option(**kwargs):
'''Get option flag to tell which criteriom is used (default:eer)'''
def criterion_option(lcriteria=['eer', 'hter', 'far'], **kwargs):
"""Get option flag to tell which criteriom is used (default:eer)
Parameters
----------
lcriteria : :any:`list`
List of possible criteria
"""
def custom_criterion_option(func):
def callback(ctx, param, value):
list_accepted_crit = ['eer', 'hter']
list_accepted_crit = lcriteria if lcriteria is not None else \
['eer', 'hter', 'far']
if value not in list_accepted_crit:
raise click.BadParameter('Incorrect value for `--criter`. '
'Must be one of [`%s`]' %
......@@ -299,17 +341,20 @@ def criterion_option(**kwargs):
callback=callback, is_eager=True ,**kwargs)(func)
return custom_criterion_option
def threshold_option(**kwargs):
'''Get option for given threshold'''
def custom_threshold_option(func):
def far_option(**kwargs):
'''Get option to get far value'''
def custom_far_option(func):
def callback(ctx, param, value):
ctx.meta['thres'] = value
if value > 1 or value < 0:
raise click.BadParameter("FAR value should be between 0 and 1")
ctx.meta['far_value'] = value
return value
return click.option(
'--thres', type=click.FLOAT, default=None,
help='Given threshold for metrics computations',
'-f', '--far-value', type=click.FLOAT, default=1e-2,
help='The FAR value for which to compute metrics',
callback=callback, show_default=True,**kwargs)(func)
return custom_threshold_option
return custom_far_option
def rank_option(**kwargs):
'''Get option for rank parameter'''
......
......@@ -203,8 +203,12 @@ class Metrics(MeasureBase):
_open_mode: str
Open mode of the output file (e.g. `w`, `a+`)
_thres: :obj:`float`
If given, uses this threshold instead of computing it
_thres: :any:`list`
If given, uses those threshold instead of computing them. Lenght of the
list must be the same as the number of systems.
_far: :obj:`float`
If given, uses this FAR to compute threshold
_log: str
Path to output log file
......@@ -221,6 +225,16 @@ class Metrics(MeasureBase):
self._open_mode = None if 'open_mode' not in ctx.meta else\
ctx.meta['open_mode']
self._thres = None if 'thres' not in ctx.meta else ctx.meta['thres']
if self._thres is not None :
if len(self._thres) == 1:
self._thres = self._thres * len(self.dev_names)
elif len(self._thres) != len(self.dev_names):
raise click.BadParameter(
'#thresholds must be the same as #systems (%d)' \
% len(self.dev_names)
)
self._far = None if 'far_value' not in ctx.meta else \
ctx.meta['far_value']
self._log = None if 'log' not in ctx.meta else ctx.meta['log']
self.log_file = sys.stdout
if self._log is not None:
......@@ -228,12 +242,12 @@ class Metrics(MeasureBase):
def compute(self, idx, dev_score, dev_file=None,
test_score=None, test_file=None):
''' Compute metrics thresholds and tables (FAR, FMR, FMNR, HTER) for
''' Compute metrics thresholds and tables (FAR, FMR, FNMR, HTER) for
given system inputs'''
dev_neg, dev_pos, dev_fta, test_neg, test_pos, test_fta =\
self._process_scores(dev_score, test_score)
threshold = utils.get_thres(self._criter, dev_neg, dev_pos) \
if self._thres is None else self._thres
threshold = utils.get_thres(self._criter, dev_neg, dev_pos, self._far) \
if self._thres is None else self._thres[idx]
if self._thres is None:
click.echo("[Min. criterion: %s] Threshold on Development set `%s`: %e"\
% (self._criter.upper(), dev_file, threshold),
......@@ -286,7 +300,7 @@ class Metrics(MeasureBase):
test_frr_str = "%.3f%%" % (100 * test_frr)
test_hter_str = "%.3f%%" % (100 * test_hter)
headers.append('Test % s' % self.test_names[idx])
headers.append('Test % s' % test_file)
raws[0].append(test_fmr_str)
raws[1].append(test_fnmr_str)
raws[2].append(test_far_str)
......@@ -559,8 +573,9 @@ class Hist(PlotBase):
_nbins: :obj:`int`, str
Number of bins. Default: `auto`
_thres: :obj:`float`
If given, this threshold will be used in the plots
_thres: :any:`list`
If given, uses those threshold instead of computing them. Lenght of the
list must be the same as the number of systems.
_criter: str
Criterion to compute threshold (eer or hter)
......@@ -569,6 +584,14 @@ class Hist(PlotBase):
super(Hist, self).__init__(ctx, scores, test, func_load)
self._nbins = None if 'nbins' not in ctx.meta else ctx.meta['nbins']
self._thres = None if 'thres' not in ctx.meta else ctx.meta['thres']
if self._thres is not None and len(self._thres) != len(self.dev_names):
if len(self._thres) == 1:
self._thres = self._thres * len(self.dev_names)
else:
raise click.BadParameter(
'#thresholds must be the same as #systems (%d)' \
% len(self.dev_names)
)
self._criter = None if 'criter' not in ctx.meta else ctx.meta['criter']
self._y_label = 'Dev. Scores \n (normalized)' if self._test else \
'Normalized Count'
......@@ -581,7 +604,7 @@ class Hist(PlotBase):
dev_neg, dev_pos, _, test_neg, test_pos, _ =\
self._process_scores(dev_score, test_score)
threshold = utils.get_thres(self._criter, dev_neg, dev_pos) \
if self._thres is None else self._thres
if self._thres is None else self._thres[idx]
fig = mpl.figure()
if test_neg is not None:
......
......@@ -56,7 +56,7 @@ def get_fta(scores):
fta_total += total
return ((neg, pos), fta_sum / fta_total)
def get_thres(criter, neg, pos):
def get_thres(criter, neg, pos, far=1e-3):
"""Get threshold for the given positive/negatives scores and criterion
Parameters
......@@ -82,6 +82,10 @@ def get_thres(criter, neg, pos):
elif criter == 'hter':
from . import min_hter_threshold
return min_hter_threshold(neg, pos)
elif criter == 'far':
from . import far_threshold
return far_threshold(neg, pos, far)
else:
raise click.UsageError("Incorrect plotting criterion: ``%s``" % criter)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment