figure.py 27.7 KB
Newer Older
1 2 3 4
'''Runs error analysis on score sets, outputs metrics and plots'''

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
5
import math
6
import sys
7
import os.path
8 9 10 11 12
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
13
from .. import (far_threshold, plot, utils, ppndf)
14

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
15

16 17 18 19 20 21 22 23 24 25
class MeasureBase(object):
    """Base class for metrics and plots.

    This abstract class defines the framework to plot or compute metrics from
    a list of (positive, negative) scores tuples.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    n_systems : int
        Number of systems, i.e. the number of input files divided by the
        number of files expected per system (``min_arg``)
    """
    # Python 2.7 style metaclass declaration (Python 3 ignores this attribute)
    __metaclass__ = ABCMeta

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx :
            Click context. Options are read from its ``meta`` dictionary
            (``min_arg`` and ``legends`` here).
        scores : :any:`list`
            List of input files (e.g. dev-{1, 2, 3}, {dev,eval}-scores1
            {dev,eval}-scores2)
        evaluation : :py:class:`bool`
            True if eval data are used
        func_load :
            Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If ``scores`` is empty, if its length is not a multiple of
            ``min_arg``, or if the number of legends differs from the number
            of systems.
        """
        self._ctx = ctx
        self._scores = scores
        self._eval = evaluation
        self.func_load = func_load
        # number of input files expected for each system (defaults to 1)
        self._min_arg = ctx.meta.get('min_arg', 1)
        self._legends = ctx.meta.get('legends')

        if len(scores) < 1 or len(scores) % self._min_arg != 0:
            raise click.BadParameter(
                'Number of argument must be a non-zero multiple of %d' % self._min_arg
            )
        # exact: the check above guarantees len(scores) is a multiple
        self.n_systems = len(scores) // self._min_arg

        if self._legends is not None and len(self._legends) != self.n_systems:
            raise click.BadParameter("Number of legends must be equal to the "
                                     "number of systems")

    def run(self):
        """ Generate outputs (e.g. metrics, files, pdf plots).
        This function calls
        :func:`~bob.measure.script.figure.MeasureBase.init_process` (before
        loop), :py:func:`~bob.measure.script.figure.MeasureBase.compute`
        (in the loop iterating through the different
        systems) and :py:func:`~bob.measure.script.figure.MeasureBase.end_process`
        (after the loop).
        """
        # init matplotlib, log files, ...
        self.init_process()

        # Scores are given as followed:
        # SysA-dev SysA-eval ... SysA-XX  SysB-dev SysB-eval ... SysB-XX
        # ------------------------------  ------------------------------
        #   First set of `self._min_arg`     Second set of input files
        #     input files starting at               for SysB
        #    index idx * self._min_arg
        # Note that more than one dev or eval scores file can be passed to
        # each system.
        for idx in range(self.n_systems):
            first = idx * self._min_arg
            system_files = self._scores[first:first + self._min_arg]
            # get the loaded scores and the base-names of the files
            input_scores, input_names = self._load_files(system_files)
            self.compute(idx, input_scores, input_names)

        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        before iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots from the given scores provided by
        :py:func:`~bob.measure.script.figure.MeasureBase.run`.
        Should be reimplemented in derived classes

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        """
        pass

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        after iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # common protected functions

    def _load_files(self, filepaths):
        ''' Load the input files and return the base names of the files

        Parameters
        ----------
        filepaths : :any:`list`
            Paths of the files to load

        Returns
        -------
            scores: :any:`list`:
                A list that contains the output of
                ``func_load`` for the given files
            basenames: :any:`list`:
                A list of basenames for the given files (file name without
                directory, cut at the first ``.``)
        '''
        scores = []
        basenames = []
        for path in filepaths:
            file_name = os.path.basename(path)
            basenames.append(file_name.split(".")[0])
            scores.append(self.func_load(path))
        return scores, basenames
141

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
142

143 144 145 146 147 148 149 150
class Metrics(MeasureBase):
    ''' Compute metrics from score files

    Attributes
    ----------
    log_file: file-like object
        output stream: ``sys.stdout``, or the file opened from the ``log``
        option when one is given
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Read the metric options from ``ctx.meta`` and open the log file.

        Options read: ``tablefmt``, ``criterion``, ``open_mode``, ``thres``,
        ``far_value`` and ``log``; each one defaults to ``None``.

        Raises
        ------
        click.BadParameter
            If more than one threshold is given and their number differs from
            the number of systems.
        '''
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._tablefmt = meta.get('tablefmt')
        self._criterion = meta.get('criterion')
        self._open_mode = meta.get('open_mode')
        self._thres = meta.get('thres')
        if self._thres is not None:
            if len(self._thres) == 1:
                # a single threshold is shared by all the systems
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._far = meta.get('far_value')
        self._log = meta.get('log')
        self.log_file = sys.stdout
        if self._log is not None:
            # `open` rejects a None mode: fall back to plain write mode
            self.log_file = open(self._log, self._open_mode or 'w')

    @staticmethod
    def _format_rates(neg, pos, fta, threshold):
        ''' Return the table cells of one score set at ``threshold``, in the
        order FtA, FMR, FNMR, FAR, FRR, HTER.'''
        from .. import farfrr
        fmr, fnmr = farfrr(neg, pos, threshold)
        # FAR / FRR take the failure-to-acquire rate into account
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)
        hter = (far + frr) / 2.0

        n_impostors = neg.shape[0]
        n_false_match = int(round(fmr * n_impostors))
        n_clients = pos.shape[0]
        n_false_non_match = int(round(fnmr * n_clients))

        return [
            "%.1f%%" % (100 * fta),
            "%.1f%% (%d/%d)" % (100 * fmr, n_false_match, n_impostors),
            "%.1f%% (%d/%d)" % (100 * fnmr, n_false_non_match, n_clients),
            "%.1f%%" % (100 * far),
            "%.1f%%" % (100 * frr),
            "%.1f%%" % (100 * hter),
        ]

    def compute(self, idx, input_scores, input_names):
        ''' Compute metrics thresholds and tables (FtA, FMR, FNMR, FAR, FRR,
        HTER) for given system inputs

        The threshold is the user-provided one for this system when ``thres``
        was given; otherwise it is computed on the development scores with
        ``utils.get_thres``. The same threshold is applied to the evaluation
        scores. The threshold and the table are written to ``log_file``.

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        '''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]
            eval_file = input_names[1]

        # the legend of the system, when given, is used as its display name
        title = self._legends[idx] if self._legends is not None else None

        if self._thres is None:
            threshold = utils.get_thres(
                self._criterion, dev_neg, dev_pos, self._far)
            far_str = ''
            if self._criterion == 'far' and self._far is not None:
                far_str = str(self._far)
            click.echo("[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                       % (self._criterion.upper(),
                          far_str, title or dev_file,
                          threshold),
                       file=self.log_file)
        else:
            threshold = self._thres[idx]
            click.echo("[Min. criterion: user provided] Threshold on "
                       "Development set `%s`: %e"
                       % (title or dev_file, threshold), file=self.log_file)

        headers = [title or '', 'Development %s' % dev_file]
        columns = [self._format_rates(dev_neg, dev_pos, dev_fta, threshold)]

        if self._eval:
            # computes statistics for the eval set based on the threshold a
            # priori
            headers.append('Eval. %s' % eval_file)
            columns.append(
                self._format_rates(eval_neg, eval_pos, eval_fta, threshold))

        row_names = ['FtA', 'FMR', 'FNMR', 'FAR', 'FRR', 'HTER']
        raws = [
            [name] + [column[row] for column in columns]
            for row, name in enumerate(row_names)
        ]

        click.echo(tabulate(raws, headers, self._tablefmt), file=self.log_file)

    def end_process(self):
        ''' Close log file if needed'''
        if self._log is not None:
            self.log_file.close()

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
266

267 268 269 270
class PlotBase(MeasureBase):
    ''' Base class for plots. Regroup several options and code
    shared by the different plots
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Read the plotting options from ``ctx.meta``.

        Every option is optional. When ``style`` is given it is applied
        through ``mpl.style.use``.
        '''
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._output = meta.get('output')
        self._points = meta.get('points', 100)
        self._split = meta.get('split')
        self._axlim = meta.get('axlim')
        self._disp_legend = meta.get('disp_legend', True)
        self._legend_loc = meta.get('legend_loc')

        # power of ten of the smallest FAR value, when it can be determined
        self._min_dig = None
        if 'min_far_value' in meta:
            self._min_dig = int(math.log10(meta['min_far_value']))
        elif self._axlim is not None:
            lower = self._axlim[0]
            # axis limits may contain None entries (see `_set_axis`); a
            # missing lower bound leaves `_min_dig` undefined
            if lower is not None:
                self._min_dig = int(math.log10(lower) if lower != 0 else 0)

        self._clayout = meta.get('clayout')
        self._far_at = meta.get('lines_at')
        # x positions of the vertical lines; subclasses may transform them
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            self._eval_points = {line: [] for line in self._far_at}
            self._lines_val = []
        self._print_fn = meta.get('show_fn', True)
        self._x_rotation = meta.get('x_rotation')
        if 'style' in meta:
            mpl.style.use(meta['style'])
        # a second figure is only needed for the eval data of a split plot
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._line_linestyles = meta.get('line_linestyles', False)
        self._linestyles = utils.get_linestyles(
            self.n_systems, self._line_linestyles)
        self._states = ['Development', 'Evaluation']
        self._title = meta.get('title')
        self._x_label = meta.get('x_label')
        self._y_label = meta.get('y_label')
        self._grid_color = 'silver'
        self._pdf_page = None
        self._end_setup_plot = True

    def init_process(self):
        ''' Open the pdf (unless one is provided in ``ctx.meta['PdfPages']``)
        and create cleared figures, using ``figsize`` if provided '''
        if not hasattr(matplotlib, 'backends'):
            matplotlib.use('pdf')

        meta = self._ctx.meta
        if 'PdfPages' in meta:
            self._pdf_page = meta['PdfPages']
        else:
            self._pdf_page = PdfPages(self._output)

        figsize = meta.get('figsize')
        for i in range(self._nb_figs):
            fig = mpl.figure(i + 1, figsize=figsize)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        ''' Set title, legend, axis labels, grid colors, save figures, draw
        lines and close pdf if needed '''
        # draw vertical lines
        if self._far_at is not None:
            for (line, line_trans) in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans], [-100.0, 100.], "--",
                    color='black'
                )
                if self._eval and self._split:
                    mpl.figure(2)
                    # connect the eval points of this line, ordered by x
                    points = sorted(
                        self._eval_points[line], key=lambda point: point[0]
                    )
                    mpl.plot([x for x, _ in points],
                             [y for _, y in points], '--',
                             color='black')

        # only for plots
        if self._end_setup_plot:
            # a missing title is handled like an empty one
            base_title = self._title or ''
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = base_title
                if not self._eval:
                    title += (" (%s)" % self._states[0])
                elif self._split:
                    title += (" (%s)" % self._states[i])
                # a blank title disables the title altogether
                mpl.title(title if base_title.replace(' ', '') else '')
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                if self._disp_legend:
                    mpl.legend(loc=self._legend_loc)
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        # The PDF is kept open only when `closef` is explicitly false (e.g.
        # when running evaluate). A PDF created by `init_process` must be
        # closed here as well, since nothing else closes it.
        if self._ctx.meta.get('closef', True):
            self._pdf_page.close()

    # common protected functions

    def _label(self, base, name, idx):
        ''' Return the legend label of system ``idx``: its user-provided
        legend if any, else ``base`` followed by the system number (when
        there are several systems) and ``name`` '''
        if self._legends is not None and len(self._legends) > idx:
            return self._legends[idx]
        if self.n_systems > 1:
            return base + (" %d (%s)" % (idx + 1, name))
        return base + (" (%s)" % name)

    def _set_axis(self):
        ''' Apply the axis limits when they are all defined '''
        if self._axlim is not None and None not in self._axlim:
            mpl.axis(self._axlim)
392

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
393

394
class Roc(PlotBase):
    ''' Handles the plotting of ROC'''

    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Set the ROC defaults for the title, axis labels, legend location
        and axis limits that were not provided by the user '''
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._title = self._title or 'ROC'
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or "1 - False Negative Rate"
        # `is None` rather than `or`: 0 is a valid matplotlib location code
        if self._legend_loc is None:
            self._legend_loc = 4
        # custom defaults
        if self._axlim is None:
            self._axlim = [1e-4, 1.0, 0, 1.0]
        else:
            # private copy: the limits are modified below, which must neither
            # alter the caller's sequence nor fail on a tuple
            self._axlim = list(self._axlim)

        if self._min_dig is not None:
            self._axlim[0] = math.pow(10, self._min_dig)

    def compute(self, idx, input_scores, input_names):
        ''' Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc_for_far`

        The development curve is drawn on figure 1. With eval data, the eval
        curve is drawn on figure 2 when ``split`` is set, else dashed on
        figure 1. For each FAR value of ``lines_at``, the eval point obtained
        at the corresponding development threshold is drawn and recorded in
        ``_eval_points``.

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]

        mpl.figure(1)
        plot.roc_for_far(
            dev_neg, dev_pos,
            far_values=plot.log_values(self._min_dig or -4),
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label('development', dev_file, idx)
        )
        if not self._eval:
            return

        eval_neg, eval_pos = neg_list[1], pos_list[1]
        eval_file = input_names[1]
        if self._split:
            mpl.figure(2)
            linestyle = self._linestyles[idx]
        else:
            # same figure as the dev curve: tell them apart with dashes
            linestyle = '--'
        plot.roc_for_far(
            eval_neg, eval_pos, linestyle=linestyle,
            far_values=plot.log_values(self._min_dig or -4),
            color=self._colors[idx],
            label=self._label('eval', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                # threshold chosen on dev, applied on eval
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, thres_line)
                # the y axis of the ROC is 1 - FNMR
                eval_fnmr = 1 - eval_fnmr
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
455

456 457
class Det(PlotBase):
    ''' Handles the plotting of DET '''

    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Set the DET defaults for the title, axis labels, legend location,
        x tick rotation and axis limits that were not provided by the user '''
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._title = self._title or 'DET'
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or 'False Negative Rate'
        # `is None` rather than `or`: 0 is a valid matplotlib location code
        if self._legend_loc is None:
            self._legend_loc = 1
        if self._far_at is not None:
            # the vertical lines are drawn at the ppndf of the FAR values
            self._trans_far_val = [ppndf(float(k)) for k in self._far_at]
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

        if self._axlim is None:
            self._axlim = [0.01, 99, 0.01, 99]
        else:
            # private copy: the limits are modified below, which must neither
            # alter the caller's sequence nor fail on a tuple
            self._axlim = list(self._axlim)

        if self._min_dig is not None:
            self._axlim[0] = math.pow(10, self._min_dig) * 100

    def compute(self, idx, input_scores, input_names):
        ''' Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`

        The development curve is drawn on figure 1. With eval data, the eval
        curve is drawn on figure 2 when ``split`` is set, else dashed on
        figure 1. For each FAR value of ``lines_at``, the eval point obtained
        at the corresponding development threshold is drawn and recorded in
        ``_eval_points``.

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        mpl.figure(1)
        plot.det(
            dev_neg, dev_pos, self._points, color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label('development', dev_file, idx)
        )
        if not (self._eval and eval_neg is not None):
            return

        if self._split:
            mpl.figure(2)
            linestyle = self._linestyles[idx]
        else:
            # same figure as the dev curve: tell them apart with dashes
            linestyle = '--'
        plot.det(
            eval_neg, eval_pos, self._points, color=self._colors[idx],
            linestyle=linestyle,
            label=self._label('eval', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                # threshold chosen on dev, applied on eval
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, thres_line)
                # both rates are transformed with ppndf before plotting
                eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

    def _set_axis(self):
        ''' Apply the axis limits through ``plot.det_axis`` '''
        plot.det_axis(self._axlim)
520

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
521

522 523
class Epc(PlotBase):
    ''' Handles the plotting of EPC '''

    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Set the EPC defaults and force the options an EPC requires: eval
        data always used, a single non-split figure, no vertical lines.

        Raises
        ------
        click.UsageError
            If the number of input files per system (``min_arg``) is not 2.
        '''
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev and eval score files")
        self._title = self._title or 'EPC'
        self._x_label = self._x_label or r'$\alpha$'
        self._y_label = self._y_label or 'HTER (%)'
        # `is None` rather than `or`: 0 is a valid matplotlib location code
        if self._legend_loc is None:
            self._legend_loc = 9
        self._eval = True  # always eval data with EPC
        self._split = False
        self._nb_figs = 1
        self._far_at = None

    def compute(self, idx, input_scores, input_names):
        ''' Plot EPC using :py:func:`bob.measure.plot.epc`

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function; the first entry
            is used as dev data and the second as eval data
        input_names: :any:`list`
            list of base names for the input file of the system
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        # eval data are always present: `__init__` forces `_eval` to True
        eval_neg, eval_pos = neg_list[1], pos_list[1]
        dev_file, eval_file = input_names[0], input_names[1]

        plot.epc(
            dev_neg, dev_pos, eval_neg, eval_pos, self._points,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label(
                'curve', dev_file + "_" + eval_file, idx
            )
        )

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
555

556
class Hist(PlotBase):
    ''' Functional base class for histograms'''

    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Read the histogram options from ``ctx.meta`` (``n_bins``,
        ``thres``, ``criterion``, ``n_row``, ``n_col``, ``legends_ncol``).

        Raises
        ------
        click.BadParameter
            If more than one threshold is given and their number differs from
            the number of systems.
        '''
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._nbins = meta.get('n_bins', [])
        self._thres = meta.get('thres')
        if self._thres is not None and len(self._thres) != self.n_systems:
            if len(self._thres) == 1:
                # a single threshold is shared by all the systems
                self._thres = self._thres * self.n_systems
            else:
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._criterion = meta.get('criterion')
        self._nrows = meta.get('n_row', 1)
        self._ncols = meta.get('n_col', 1)
        self._nlegends = meta.get('legends_ncol', 10)
        # number of histograms drawn on each page
        self._step_print = int(self._nrows * self._ncols)
        self._title_base = 'Scores'
        # keep the labels given by the user, like the other plots do
        self._y_label = self._y_label or 'Probability density'
        self._x_label = self._x_label or 'Scores values'
        # the figures are saved page by page in `compute`
        self._end_setup_plot = False

    def compute(self, idx, input_scores, input_names):
        ''' Draw histograms of negative and positive scores.

        Each system is drawn in its own subplot of a ``n_row`` x ``n_col``
        grid. The eval scores are drawn when available, else the dev scores.
        The page is saved to the pdf once the grid is full or the last system
        has been drawn.

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        '''
        dev_neg, dev_pos, eval_neg, eval_pos, threshold = \
            self._get_neg_pos_thres(idx, input_scores, input_names)
        dev_file = input_names[0]
        eval_file = input_names[1] if len(input_names) == 2 else None

        # position of this system on the current page
        n = idx % self._step_print
        col = n % self._ncols
        sub_plot_idx = n + 1
        axis = mpl.subplot(self._nrows, self._ncols, sub_plot_idx)
        neg = eval_neg if eval_neg is not None else dev_neg
        pos = eval_pos if eval_pos is not None else dev_pos
        self._setup_hist(neg, pos)
        if col == 0:
            # y label on the first column only
            axis.set_ylabel(self._y_label)
        # number of systems still to be drawn, from the start of this page
        rest_print = self.n_systems - \
            (idx // self._step_print) * self._step_print
        if n + self._ncols >= min(self._step_print, rest_print):
            # x label on the last subplot of each column only
            axis.set_xlabel(self._x_label)
        axis.set_title(self._get_title(idx, dev_file, eval_file))
        label = "%s threshold%s" % (
            '' if self._criterion is None else
            self._criterion.upper(), ' (dev)' if self._eval else ''
        )
        self._lines(threshold, label, neg, pos, idx)
        if sub_plot_idx == 1:
            self._plot_legends()
        if self._step_print == sub_plot_idx or idx == self.n_systems - 1:
            # page complete: save it and start a new figure
            mpl.tight_layout()
            self._pdf_page.savefig(mpl.gcf(), bbox_inches="tight")
            mpl.clf()
            mpl.figure()

    def _get_title(self, idx, dev_file, eval_file):
        ''' Return the subplot title of system ``idx``: its legend if any,
        else the user title, else the base title. A blank user title
        disables the titles.

        ``dev_file`` and ``eval_file`` are not used here.
        '''
        if self._title is not None and not self._title.replace(' ', ''):
            return ''
        legend = self._legends[idx] if self._legends is not None else None
        return legend or self._title or self._title_base or ''

    def _plot_legends(self):
        ''' Gather the legend entries of all the axes of the current figure
        in a single figure legend, when legends are enabled '''
        lines = []
        labels = []
        for ax in mpl.gcf().get_axes():
            handles, names = ax.get_legend_handles_labels()
            lines += handles
            labels += names
        if self._disp_legend:
            mpl.gcf().legend(
                lines, labels, fontsize=6, loc=self._legend_loc, fancybox=True,
                framealpha=0.5, ncol=self._nlegends,
            )

    def _get_neg_pos_thres(self, idx, input_scores, input_names):
        ''' Split the scores of a system into dev / eval lists and return
        them with the threshold of the system.

        Files at even positions are dev files and files at odd positions are
        eval files. The eval lists are ``None`` when eval data are not used.
        The threshold is the user-provided one when ``thres`` was given,
        else it is computed with ``utils.get_thres`` on the first dev file.
        ``input_names`` is not used here.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        # can have several files for one system
        dev_neg = neg_list[0::2]
        dev_pos = pos_list[0::2]
        eval_neg = eval_pos = None
        if self._eval:
            eval_neg = neg_list[1::2]
            eval_pos = pos_list[1::2]

        if self._thres is None:
            threshold = utils.get_thres(
                self._criterion, dev_neg[0], dev_pos[0]
            )
        else:
            threshold = self._thres[idx]
        return dev_neg, dev_pos, eval_neg, eval_pos, threshold

    def _density_hist(self, scores, n, **kwargs):
        ''' Draw the density histogram of ``scores`` and return the
        ``(values, bins, patches)`` tuple of ``mpl.hist``.

        ``n`` selects the number of bins in ``n_bins``; automatic binning is
        used when ``n_bins`` has no entry at that index. Extra keyword
        arguments are passed to ``mpl.hist``.
        '''
        n_bins = self._nbins[n] if len(self._nbins) > n else 'auto'
        values, bins, patches = mpl.hist(
            scores, density=True, bins=n_bins, **kwargs
        )
        return (values, bins, patches)

    def _lines(self, threshold, label=None, neg=None, pos=None,
               idx=None, **kwargs):
        ''' Draw a vertical line at ``threshold`` (dashed, color ``C3`` and
        labelled ``label`` unless overridden through ``kwargs``).

        ``neg``, ``pos`` and ``idx`` are not used here.
        '''
        label = label or 'Threshold'
        kwargs.setdefault('color', 'C3')
        kwargs.setdefault('linestyle', '--')
        kwargs.setdefault('label', label)
        # plot a vertical threshold line
        mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs)

    def _setup_hist(self, neg, pos):
        ''' Draw the histograms of the first negative and positive score
        sets. This function can be overwritten in derived classes'''
        self._density_hist(
            neg[0], n=0,
            label='Negatives', alpha=0.5, color='C3'
        )
        self._density_hist(
            pos[0], n=1,
            label='Positives', alpha=0.5, color='C0'
        )