From 44c25df8a2b1e7d78a09d291e354e59a74fb594e Mon Sep 17 00:00:00 2001
From: Theophile GENTILHOMME <tgentilhomme@jurasix08.idiap.ch>
Date: Fri, 6 Apr 2018 16:41:21 +0200
Subject: [PATCH] Add and concanate click options, add threshold criterion

---
 bob/measure/script/commands.py       |  5 +-
 bob/measure/script/common_options.py | 97 ++++++++++++++++++++--------
 bob/measure/script/figure.py         | 41 +++++++++---
 bob/measure/utils.py                 |  6 +-
 4 files changed, 111 insertions(+), 38 deletions(-)

diff --git a/bob/measure/script/commands.py b/bob/measure/script/commands.py
index 2d2b87f..8552ff8 100644
--- a/bob/measure/script/commands.py
+++ b/bob/measure/script/commands.py
@@ -15,7 +15,8 @@ from bob.extension.scripts.click_helper import verbosity_option
 @common_options.open_file_mode_option()
 @common_options.output_plot_metric_option()
 @common_options.criterion_option()
-@common_options.threshold_option()
+@common_options.thresholds_option()
+@common_options.far_option()
 @verbosity_option()
 @click.pass_context
 def metrics(ctx, scores, test, **kwargs):
@@ -143,7 +144,7 @@ def epc(ctx, scores, **kwargs):
 @common_options.n_bins_option()
 @common_options.criterion_option()
 @common_options.axis_fontsize_option()
-@common_options.threshold_option()
+@common_options.thresholds_option()
 @verbosity_option()
 @click.pass_context
 def hist(ctx, scores, test, **kwargs):
diff --git a/bob/measure/script/common_options.py b/bob/measure/script/common_options.py
index d9691f1..6a5d85e 100644
--- a/bob/measure/script/common_options.py
+++ b/bob/measure/script/common_options.py
@@ -87,40 +87,60 @@ def semilogx_option(dflt= False, **kwargs):
             callback=callback, **kwargs)(func)
     return custom_semilogx_option
 
-def axes_val_option(dflt=None, **kwargs):
-    '''Get option for min/max values for axes. If one the default is None, no
-    default is used
-
+def list_float_option(name, short_name, desc, nitems=None, dflt=None, **kwargs):
+    '''Get option to get a list of float f
     Parameters
     ----------
 
+    name: str
+        name of the option
+    short_name: str
+        short name for the option
+    desc: str
+        short description for the option
+    nitems: :obj:`int`
+        if given, the parsed list must contains this number of items
     dflt: :any:`list`
-        List of default min/max values for axes. Must be of length 4
+        List of default  values for axes.
     '''
-    def custom_axes_val_option(func):
+    def custom_list_float_option(func):
         def callback(ctx, param, value):
             if value is not None:
                 tmp = value.split(',')
-                if len(tmp) != 4:
-                    raise click.BadParameter('Must provide 4 axis limits')
+                if nitems is not None and len(tmp) != nitems:
+                    raise click.BadParameter(
+                        '%s Must provide %d axis limits' % (name, nitems)
+                    )
                 try:
                     value = [float(i) for i in tmp]
                 except:
-                    raise click.BadParameter('Axis limits must be floats')
+                    raise click.BadParameter('Inputs of %s be floats' % name)
                 if None in value:
                     value = None
-                elif None not in dflt and len(dflt) == 4:
+                elif dflt is not None and None not in dflt and len(dflt) == nitems:
                     value = dflt if not all(
                         isinstance(x, float) for x in dflt
                     ) else None
-            ctx.meta['axlim'] = value
+            ctx.meta[name.replace('-', '_')] = value
             return value
         return click.option(
-            '-L', '--axlim', default=None, show_default=True,
-            help='min/max axes values separated by commas (min_x, max_x, '
-            'min_y, max_y)',
-            callback=callback, **kwargs)(func)
-    return custom_axes_val_option
+            '-'+short_name, '--'+name, default=None, show_default=True,
+            help=desc, callback=callback, **kwargs)(func)
+    return custom_list_float_option
+
+def axes_val_option(dflt=None, **kwargs):
+    return list_float_option(
+        name='axlim', short_name='L',
+        desc='min/max axes values separated by commas (min_x, max_x, min_y, max_y)',
+        nitems=4, dflt=dflt, **kwargs
+    )
+
+def thresholds_option(**kwargs):
+    return list_float_option(
+        name='thres', short_name='T',
+        desc='Given threshold for metrics computations',
+        nitems=None, dflt=None, **kwargs
+    )
 
 def axis_fontsize_option(dflt=8, **kwargs):
     '''Get option for axis font size'''
@@ -162,6 +182,20 @@ def fmr_line_at_option(**kwargs):
             callback=callback, **kwargs)(func)
     return custom_fmr_line_at_option
 
+def cost_option(**kwargs):
+    '''Get option to get cost for FAR'''
+    def custom_cost_option(func):
+        def callback(ctx, param, value):
+            if value < 0 or value > 1:
+                raise click.BadParameter("Cost for FAR must be betwen 0 and 1")
+            ctx.meta['cost'] = value
+            return value
+        return click.option(
+            '-C', '--cost', type=float, default=0.99, show_default=True,
+            help='Cost for FAR in minDCF',
+            callback=callback, **kwargs)(func)
+    return custom_cost_option
+
 def n_sys_option(**kwargs):
     '''Get the number of systems to be processed'''
     def custom_n_sys_option(func):
@@ -282,11 +316,19 @@ def open_file_mode_option(**kwargs):
             callback=callback, **kwargs)(func)
     return custom_open_file_mode_option
 
-def criterion_option(**kwargs):
-    '''Get option flag to tell which criteriom is used (default:eer)'''
+def criterion_option(lcriteria=['eer', 'hter', 'far'], **kwargs):
+    """Get option flag to tell which criteriom is used (default:eer)
+
+    Parameters
+    ----------
+
+    lcriteria : :any:`list`
+        List of possible criteria
+    """
     def custom_criterion_option(func):
         def callback(ctx, param, value):
-            list_accepted_crit = ['eer', 'hter']
+            list_accepted_crit = lcriteria if lcriteria is not None else \
+                    ['eer', 'hter', 'far']
             if value not in list_accepted_crit:
                 raise click.BadParameter('Incorrect value for `--criter`. '
                                          'Must be one of [`%s`]' %
@@ -299,17 +341,20 @@ def criterion_option(**kwargs):
             callback=callback, is_eager=True ,**kwargs)(func)
     return custom_criterion_option
 
-def threshold_option(**kwargs):
-    '''Get option for given threshold'''
-    def custom_threshold_option(func):
+
+def far_option(**kwargs):
+    '''Get option to get far value'''
+    def custom_far_option(func):
         def callback(ctx, param, value):
-            ctx.meta['thres'] = value
+            if value > 1 or value < 0:
+                raise click.BadParameter("FAR value should be between 0 and 1")
+            ctx.meta['far_value'] = value
             return value
         return click.option(
-            '--thres', type=click.FLOAT, default=None,
-            help='Given threshold for metrics computations',
+            '-f', '--far-value', type=click.FLOAT, default=1e-2,
+            help='The FAR value for which to compute metrics',
             callback=callback, show_default=True,**kwargs)(func)
-    return custom_threshold_option
+    return custom_far_option
 
 def rank_option(**kwargs):
     '''Get option for rank parameter'''
diff --git a/bob/measure/script/figure.py b/bob/measure/script/figure.py
index 41f67fd..7087000 100644
--- a/bob/measure/script/figure.py
+++ b/bob/measure/script/figure.py
@@ -203,8 +203,12 @@ class Metrics(MeasureBase):
     _open_mode: str
         Open mode of the output file (e.g. `w`, `a+`)
 
-    _thres: :obj:`float`
-        If given, uses this threshold instead of computing it
+    _thres: :any:`list`
+        If given, uses those threshold instead of computing them. Lenght of the
+        list must be the same as the number of systems.
+
+    _far: :obj:`float`
+        If given, uses this FAR to compute threshold
 
     _log: str
         Path to output log file
@@ -221,6 +225,16 @@ class Metrics(MeasureBase):
         self._open_mode = None if 'open_mode' not in ctx.meta else\
                 ctx.meta['open_mode']
         self._thres = None if 'thres' not in ctx.meta else ctx.meta['thres']
+        if self._thres is not None :
+            if len(self._thres) == 1:
+                self._thres = self._thres * len(self.dev_names)
+            elif len(self._thres) != len(self.dev_names):
+                raise click.BadParameter(
+                    '#thresholds must be the same as #systems (%d)' \
+                    % len(self.dev_names)
+                )
+        self._far = None if 'far_value' not in ctx.meta else \
+        ctx.meta['far_value']
         self._log = None if 'log' not in ctx.meta else ctx.meta['log']
         self.log_file = sys.stdout
         if self._log is not None:
@@ -228,12 +242,12 @@ class Metrics(MeasureBase):
 
     def compute(self, idx, dev_score, dev_file=None,
                 test_score=None, test_file=None):
-        ''' Compute metrics thresholds and tables (FAR, FMR, FMNR, HTER) for
+        ''' Compute metrics thresholds and tables (FAR, FMR, FNMR, HTER) for
         given system inputs'''
         dev_neg, dev_pos, dev_fta, test_neg, test_pos, test_fta =\
                 self._process_scores(dev_score, test_score)
-        threshold = utils.get_thres(self._criter, dev_neg, dev_pos) \
-                if self._thres is None else self._thres
+        threshold = utils.get_thres(self._criter, dev_neg, dev_pos, self._far) \
+                if self._thres is None else self._thres[idx]
         if self._thres is None:
             click.echo("[Min. criterion: %s] Threshold on Development set `%s`: %e"\
                        % (self._criter.upper(), dev_file, threshold),
@@ -286,7 +300,7 @@ class Metrics(MeasureBase):
             test_frr_str = "%.3f%%" % (100 * test_frr)
             test_hter_str = "%.3f%%" % (100 * test_hter)
 
-            headers.append('Test % s' % self.test_names[idx])
+            headers.append('Test % s' % test_file)
             raws[0].append(test_fmr_str)
             raws[1].append(test_fnmr_str)
             raws[2].append(test_far_str)
@@ -559,8 +573,9 @@ class Hist(PlotBase):
     _nbins: :obj:`int`, str
     Number of bins. Default: `auto`
 
-    _thres: :obj:`float`
-        If given, this threshold will be used in the plots
+    _thres: :any:`list`
+        If given, uses those threshold instead of computing them. Lenght of the
+        list must be the same as the number of systems.
 
     _criter: str
         Criterion to compute threshold (eer or hter)
@@ -569,6 +584,14 @@ class Hist(PlotBase):
         super(Hist, self).__init__(ctx, scores, test, func_load)
         self._nbins = None if 'nbins' not in ctx.meta else ctx.meta['nbins']
         self._thres = None if 'thres' not in ctx.meta else ctx.meta['thres']
+        if self._thres is not None and len(self._thres) != len(self.dev_names):
+            if len(self._thres) == 1:
+                self._thres = self._thres * len(self.dev_names)
+            else:
+                raise click.BadParameter(
+                    '#thresholds must be the same as #systems (%d)' \
+                    % len(self.dev_names)
+                )
         self._criter = None if 'criter' not in ctx.meta else ctx.meta['criter']
         self._y_label = 'Dev. Scores \n (normalized)' if self._test else \
                 'Normalized Count'
@@ -581,7 +604,7 @@ class Hist(PlotBase):
         dev_neg, dev_pos, _, test_neg, test_pos, _ =\
                 self._process_scores(dev_score, test_score)
         threshold = utils.get_thres(self._criter, dev_neg, dev_pos) \
-                if self._thres is None else self._thres
+                if self._thres is None else self._thres[idx]
 
         fig = mpl.figure()
         if test_neg is not None:
diff --git a/bob/measure/utils.py b/bob/measure/utils.py
index 6cdc1e3..c938e4c 100644
--- a/bob/measure/utils.py
+++ b/bob/measure/utils.py
@@ -56,7 +56,7 @@ def get_fta(scores):
     fta_total += total
     return ((neg, pos), fta_sum / fta_total)
 
-def get_thres(criter, neg, pos):
+def get_thres(criter, neg, pos, far=1e-3):
     """Get threshold for the given positive/negatives scores and criterion
 
     Parameters
@@ -82,6 +82,10 @@ def get_thres(criter, neg, pos):
     elif criter == 'hter':
         from . import min_hter_threshold
         return min_hter_threshold(neg, pos)
+    elif criter == 'far':
+        from . import far_threshold
+        return far_threshold(neg, pos, far)
+
     else:
         raise click.UsageError("Incorrect plotting criterion: ``%s``" % criter)
 
-- 
GitLab