figure.py 30 KB
Newer Older
1
2
3
4
'''Runs error analysis on score sets, outputs metrics and plots'''

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
5
import math
6
import sys
7
import os.path
8
9
10
11
12
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
13
from .. import (far_threshold, plot, utils, ppndf)
14

Amir MOHAMMADI's avatar
lint    
Amir MOHAMMADI committed
15

16
17
18
19
20
21
22
23
24
25
26
27
def check_list_value(values, desired_number, name, name2='systems'):
    """Broadcast a single-valued option, or check that its length is valid.

    Parameters
    ----------
    values : :any:`list` or None
        Values given for the option. ``None`` is passed through untouched.
    desired_number : :obj:`int`
        Number of values that is expected.
    name : str
        Name of the option, only used in the error message.
    name2 : str
        Name of what ``desired_number`` counts, only used in the error
        message.

    Returns
    -------
    :any:`list` or None
        ``values`` itself when it is ``None`` or already has
        ``desired_number`` items, otherwise the single given value repeated
        ``desired_number`` times.

    Raises
    ------
    click.BadParameter
        If ``values`` holds neither one nor ``desired_number`` items.
    """
    if values is None:
        return values

    n_given = len(values)
    if n_given == desired_number:
        return values

    if n_given != 1:
        raise click.BadParameter(
            '#{} ({}) must be either 1 value or the same as '
            '#{} ({} values)'.format(name, values, name2, desired_number))

    # a lone value is repeated so that every item receives the same one
    return values * desired_number


28
29
30
31
32
33
34
35
36
37
class MeasureBase(object):
    """Common skeleton of the metric and plot generators.

    The input score files are split into consecutive groups of ``min_arg``
    files, one group per system. :py:meth:`run` loads each group with
    ``func_load`` and hands it to :py:meth:`compute`, which the derived
    classes implement.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    n_systems : :obj:`int`
        Number of systems, i.e. number of input files divided by ``min_arg``
    """

    __metaclass__ = ABCMeta  # for python 2.7 compatibility

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx :
            Click context. Its ``meta`` dictionary is read for the
            ``legends`` and ``min_arg`` (default 1) entries.

        scores : :any:`list`
            List of input files (e.g. dev-{1, 2, 3}, {dev,eval}-scores1
            {dev,eval}-scores2)

        evaluation : :py:class:`bool`
            True if eval data are used

        func_load : Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If no file is given, if the number of files is not a multiple of
            ``min_arg``, or if fewer legends than systems are given.
        """
        meta = ctx.meta
        self._scores = scores
        self._ctx = ctx
        self.func_load = func_load
        self._legends = meta.get('legends')
        self._eval = evaluation
        self._min_arg = meta.get('min_arg', 1)

        n_files = len(scores)
        if not n_files or n_files % self._min_arg != 0:
            raise click.BadParameter(
                'Number of argument must be a non-zero multiple of %d' % self._min_arg
            )
        self.n_systems = int(n_files / self._min_arg)

        legends = self._legends
        if legends is not None and len(legends) < self.n_systems:
            raise click.BadParameter("Number of legends must be >= to the "
                                     "number of systems")

    def run(self):
        """ Generate outputs (e.g. metrics, files, pdf plots).

        Calls, in this order,
        :func:`~bob.measure.script.figure.MeasureBase.init_process` once,
        :py:func:`~bob.measure.script.figure.MeasureBase.compute` once per
        system and
        :py:func:`~bob.measure.script.figure.MeasureBase.end_process` once.
        """
        # init matplotlib, log files, ...
        self.init_process()

        # Scores are given as followed:
        # SysA-dev SysA-eval ... SysA-XX  SysB-dev SysB-eval ... SysB-XX
        # ------------------------------  ------------------------------
        #   First set of `self._min_arg`     Second set of input files
        #     input files for SysA                  for SysB
        # so the files of system `idx` start at index idx * self._min_arg.
        group_size = self._min_arg
        for idx in range(self.n_systems):
            first = idx * group_size
            last = (idx + 1) * group_size
            # arrays loaded from the files of this system and their base-names
            input_scores, input_names = self._load_files(
                self._scores[first:last]
            )
            self.compute(idx, input_scores, input_names)

        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """ Hook called by
        :py:func:`~bob.measure.script.figure.MeasureBase.run` before the
        systems are iterated over.

        Does nothing here; may be reimplemented in derived classes"""
        pass

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots for one system, from the data handed over
        by :py:func:`~bob.measure.script.figure.MeasureBase.run`.

        Should be reimplemented in derived classes

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function, one item per
            input file of the system
        input_names: :any:`list`
            list of base names for the input file of the system

        Notes
        -----
        Structure of the input (vuln example), if evaluation is provided::

            [(dev_licit_neg, dev_licit_pos), (eval_licit_neg, eval_licit_pos),
             (dev_spoof_neg, dev_licit_pos), (eval_spoof_neg, eval_licit_pos)]

        and if only dev::

            [(dev_licit_neg, dev_licit_pos), (dev_spoof_neg, dev_licit_pos)]
        """
        pass

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """ Hook called by
        :py:func:`~bob.measure.script.figure.MeasureBase.run` once all the
        systems have been processed.

        Should be reimplemented in derived classes"""
        pass

    # common protected functions

    def _load_files(self, filepaths):
        ''' Load the input files and return the base names of the files

        Parameters
        ----------
            filepaths: :any:`list`
                Paths of the files to load

        Returns
        -------
            scores: :any:`list`:
                A list that contains the output of
                ``func_load`` for the given files
            basenames: :any:`list`:
                A list of basenames for the given files: the file name
                without its directory, cut at the first ``.``
        '''
        loaded = []
        stems = []
        for path in filepaths:
            file_name = os.path.basename(path)
            stems.append(file_name.partition(".")[0])
            loaded.append(self.func_load(path))
        return loaded, stems
159

Amir MOHAMMADI's avatar
lint    
Amir MOHAMMADI committed
160

161
162
163
164
165
166
167
168
class Metrics(MeasureBase):
    ''' Compute metrics from score files

    Attributes
    ----------
    log_file:
        Output stream the thresholds and tables are written to:
        ``sys.stdout``, or the file named by the ``log`` entry of the context
        meta, opened with the ``open_mode`` entry.
    names:
        Row labels of the table, in the order FtA, FMR, FNMR, FAR, FRR, HTER
    '''

    def __init__(self, ctx, scores, evaluation, func_load,
                 names=('FtA', 'FMR', 'FNMR', 'FAR', 'FRR', 'HTER')):
        '''
        Parameters
        ----------
        ctx, scores, evaluation, func_load :
            Forwarded to :py:class:`MeasureBase`. The ``tablefmt``,
            ``criterion``, ``open_mode``, ``thres``, ``far_value`` and ``log``
            entries of ``ctx.meta`` are read here.
        names :
            Six row labels used in the output table.

        Raises
        ------
        click.BadParameter
            If the number of user thresholds is neither 1 nor the number of
            systems.
        '''
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        self.names = names
        self._tablefmt = ctx.meta.get('tablefmt')
        self._criterion = ctx.meta.get('criterion')
        self._open_mode = ctx.meta.get('open_mode')
        self._thres = ctx.meta.get('thres')
        if self._thres is not None:
            if len(self._thres) == 1:
                # a single threshold is shared by all the systems
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                # n_systems is an int: it is formatted directly (taking its
                # len() raised a TypeError instead of the intended error)
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._far = ctx.meta.get('far_value')
        self._log = ctx.meta.get('log')
        self.log_file = sys.stdout
        if self._log is not None:
            # closed in end_process
            self.log_file = open(self._log, self._open_mode)

    def get_thres(self, criterion, dev_neg, dev_pos, far):
        ''' Return the threshold given by ``utils.get_thres`` for the
        development scores. Kept as a method so that derived classes can
        override it.'''
        return utils.get_thres(criterion, dev_neg, dev_pos, far)

    def _format_statistics(self, neg, pos, fta, threshold):
        ''' Compute the error rates of one score set at ``threshold``.

        Parameters
        ----------
        neg, pos :
            Negative and positive scores; their ``shape[0]`` is used as the
            number of impostors / clients.
        fta :
            Failure to acquire rate of the score set
        threshold :
            Decision threshold passed to ``farfrr``

        Returns
        -------
        :any:`list`
            The formatted FtA, FMR, FNMR, FAR, FRR and HTER, in this order
        '''
        from .. import farfrr
        fmr, fnmr = farfrr(neg, pos, threshold)
        # failures to acquire are accounted for in FAR and FRR
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)
        hter = (far + frr) / 2.0

        n_impostors = neg.shape[0]
        n_false_matches = int(round(fmr * n_impostors))
        n_clients = pos.shape[0]
        n_false_non_matches = int(round(fnmr * n_clients))

        return [
            "%.1f%%" % (100 * fta),
            "%.1f%% (%d/%d)" % (100 * fmr, n_false_matches, n_impostors),
            "%.1f%% (%d/%d)" % (100 * fnmr, n_false_non_matches, n_clients),
            "%.1f%%" % (100 * far),
            "%.1f%%" % (100 * frr),
            "%.1f%%" % (100 * hter),
        ]

    def compute(self, idx, input_scores, input_names):
        ''' Compute metrics thresholds and tables (FAR, FMR, FNMR, HTER) for
        given system inputs

        The threshold is the user provided one for this system when
        thresholds were given, otherwise it is computed on the development
        scores with the configured criterion. The same threshold is then
        applied to the evaluation scores (if any). The threshold and the
        table are written to ``log_file``.

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            loaded scores of the system: development first, then evaluation
        input_names: :any:`list`
            base names of the corresponding files
        '''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]
            eval_file = input_names[1]

        if self._thres is None:
            threshold = self.get_thres(
                self._criterion, dev_neg, dev_pos, self._far)
        else:
            threshold = self._thres[idx]

        title = self._legends[idx] if self._legends is not None else None
        if self._thres is None:
            far_str = ''
            if self._criterion == 'far' and self._far is not None:
                far_str = str(self._far)
            click.echo("[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                       % (self._criterion.upper(),
                          far_str, title or dev_file,
                          threshold),
                       file=self.log_file)
        else:
            click.echo("[Min. criterion: user provided] Threshold on "
                       "Development set `%s`: %e"
                       % (dev_file or title, threshold), file=self.log_file)

        dev_strs = self._format_statistics(dev_neg, dev_pos, dev_fta, threshold)
        headers = [title, 'Development %s' % dev_file]
        raws = [[self.names[row], dev_strs[row]] for row in range(6)]

        if self._eval:
            # statistics for the eval set are based on the threshold a priori
            eval_strs = self._format_statistics(
                eval_neg, eval_pos, eval_fta, threshold)
            headers.append('Eval. % s' % eval_file)
            for raw, eval_str in zip(raws, eval_strs):
                raw.append(eval_str)

        click.echo(tabulate(raws, headers, self._tablefmt), file=self.log_file)

    def end_process(self):
        ''' Close log file if needed'''
        if self._log is not None:
            self.log_file.close()

Amir MOHAMMADI's avatar
lint    
Amir MOHAMMADI committed
285

286
287
288
289
class PlotBase(MeasureBase):
    ''' Base class for plots. Regroup several options and code
    shared by the different plots
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Read the plotting options from ``ctx.meta``.

        Parameters
        ----------
        ctx, scores, evaluation, func_load :
            Forwarded to :py:class:`MeasureBase`.
        '''
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._output = meta.get('output')
        self._points = meta.get('points', 100)
        self._split = meta.get('split')
        self._axlim = meta.get('axlim')
        self._disp_legend = meta.get('disp_legend', True)
        self._legend_loc = meta.get('legend_loc')

        # power of ten of the smallest FAR: taken from `min_far_value` if
        # given, else from the lower x limit (0 is mapped to 0)
        self._min_dig = None
        if 'min_far_value' in meta:
            self._min_dig = int(math.log10(meta['min_far_value']))
        elif self._axlim is not None and self._axlim[0] is not None:
            self._min_dig = int(math.log10(self._axlim[0])
                                if self._axlim[0] != 0 else 0)

        self._clayout = meta.get('clayout')
        self._far_at = meta.get('lines_at')
        # x positions of the vertical lines; derived classes may transform
        # them to their own axis scale
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            # points of the eval curves, collected per requested line
            self._eval_points = {line: [] for line in self._far_at}
            self._lines_val = []

        self._print_fn = meta.get('show_fn', True)
        self._x_rotation = meta.get('x_rotation')
        if 'style' in meta:
            mpl.style.use(meta['style'])
        # dev and eval only get their own figure when `split` is set
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._line_linestyles = meta.get('line_linestyles', False)
        self._linestyles = utils.get_linestyles(
            self.n_systems, self._line_linestyles)
        # doubled so that a single title serves both figures
        self._titles = meta.get('titles', []) * 2
        # for compatibility
        self._title = meta.get('title')
        if not self._titles and self._title is not None:
            self._titles = [self._title] * 2

        self._x_label = meta.get('x_label')
        self._y_label = meta.get('y_label')
        self._grid_color = 'silver'
        self._pdf_page = None
        # when False, end_process leaves titles, labels, legend and the
        # saving of the figures alone
        self._end_setup_plot = True

    def init_process(self):
        ''' Select the output pdf document and reset the figures.

        The ``PdfPages`` object stored in the context meta is used when there
        is one, otherwise a new document is opened on the ``output`` path.
        '''
        if not hasattr(matplotlib, 'backends'):
            matplotlib.use('pdf')

        meta = self._ctx.meta
        if 'PdfPages' in meta:
            self._pdf_page = meta['PdfPages']
        else:
            self._pdf_page = PdfPages(self._output)

        for i in range(self._nb_figs):
            fs = meta.get('figsize')
            fig = mpl.figure(i + 1, figsize=fs)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        ''' Set title, legend, axis labels, grid colors, save figures, draw
        lines and close pdf if needed '''
        # draw vertical lines
        if self._far_at is not None:
            for (line, line_trans) in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans], [-100.0, 100.], "--",
                    color='black'
                )
                if self._eval and self._split:
                    # on the eval figure, join the points collected for this
                    # line, ordered by their x value
                    mpl.figure(2)
                    points = sorted(
                        self._eval_points[line], key=lambda point: point[0]
                    )
                    x_values = [x for x, _ in points]
                    y_values = [y for _, y in points]
                    mpl.plot(x_values,
                             y_values, '--',
                             color='black')

        # only for plots
        if self._end_setup_plot:
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = '' if not self._titles else self._titles[i]
                # a title made of spaces only means "no title"
                mpl.title(title if title.replace(' ', '') else '')
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                if self._disp_legend:
                    mpl.legend(loc=self._legend_loc)
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        meta = self._ctx.meta
        if 'PdfPages' in meta:
            # shared document: do not want to close PDF when running
            # evaluate, which sets `closef` to a false value
            close_pdf = 'closef' not in meta or meta['closef']
        else:
            # document opened by init_process: nobody else can close it
            close_pdf = self._pdf_page is not None
        if close_pdf:
            self._pdf_page.close()

    # common protected functions

    def _label(self, base, name, idx):
        ''' Return the legend label of curve ``idx``: the user legend when
        one exists for this index, otherwise a label built from ``base`` and
        the file ``name`` (numbered when there are several systems).'''
        if self._legends is not None and len(self._legends) > idx:
            return self._legends[idx]
        if self.n_systems > 1:
            return base + (" %d (%s)" % (idx + 1, name))
        return base + (" (%s)" % name)

    def _set_axis(self):
        ''' Apply the axis limits to the current figure, if any are set.'''
        if self._axlim is not None:
            mpl.axis(self._axlim)
401

Amir MOHAMMADI's avatar
lint    
Amir MOHAMMADI committed
402

403
class Roc(PlotBase):
    ''' Draws the ROC curves of the development (and evaluation) scores '''

    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Same parameters as :py:class:`PlotBase`; fills in the ROC
        defaults for whatever the user left unset.'''
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        if not self._titles:
            self._titles = ['ROC dev', 'ROC eval']
        if not self._x_label:
            self._x_label = 'False Positive Rate'
        if not self._y_label:
            self._y_label = "1 - False Negative Rate"
        self._semilogx = ctx.meta.get('semilogx', True)
        if not self._legend_loc:
            if self._semilogx:
                self._legend_loc = 'lower right'
            else:
                self._legend_loc = 'upper right'
        # custom defaults
        if self._axlim is None:
            self._axlim = [None, None, -0.05, 1.05]

    def compute(self, idx, input_scores, input_names):
        ''' Plot ROC for dev and eval data using
        ``plot.roc_for_far``

        The dev curve always goes to figure 1. The eval curve goes to
        figure 2 when ``split`` is set, otherwise it is dashed on figure 1.
        For each requested FAR line, the threshold is computed on the dev
        scores and the resulting eval point is drawn and recorded.
        '''
        negatives, positives, _ = utils.get_fta_list(input_scores)
        dev_neg = negatives[0]
        dev_pos = positives[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg = negatives[1]
            eval_pos = positives[1]
            eval_file = input_names[1]

        mpl.figure(1)
        # the dev curve is drawn whether or not eval scores are used
        plot.roc_for_far(
            dev_neg, dev_pos,
            far_values=plot.log_values(self._min_dig or -4),
            CAR=self._semilogx,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label('dev', dev_file, idx)
        )
        if not self._eval:
            return

        if self._split:
            mpl.figure(2)
            eval_style = self._linestyles[idx]
        else:
            eval_style = '--'
        plot.roc_for_far(
            eval_neg, eval_pos, linestyle=eval_style,
            far_values=plot.log_values(self._min_dig or -4),
            CAR=self._semilogx,
            color=self._colors[idx],
            label=self._label('eval', eval_file, idx)
        )

        if self._far_at is None:
            return
        from .. import farfrr
        for far_line in self._far_at:
            # threshold a priori: chosen on the dev set, applied to eval
            threshold = far_threshold(dev_neg, dev_pos, far_line)
            fmr, fnmr = farfrr(eval_neg, eval_pos, threshold)
            point = (fmr, 1 - fnmr)
            mpl.scatter(point[0], point[1], c=self._colors[idx], s=30)
            self._eval_points[far_line].append(point)

Amir MOHAMMADI's avatar
lint    
Amir MOHAMMADI committed
466

467
468
class Det(PlotBase):
    ''' Handles the plotting of DET '''

    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Same parameters as :py:class:`PlotBase`; fills in the DET
        defaults for whatever the user left unset.'''
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ['DET dev', 'DET eval']
        self._x_label = self._x_label or 'False Positive Rate (%)'
        self._y_label = self._y_label or 'False Negative Rate (%)'
        self._legend_loc = self._legend_loc or 'upper right'
        if self._far_at is not None:
            # the vertical lines are positioned on the ppndf-scaled axis
            self._trans_far_val = [ppndf(float(k)) for k in self._far_at]
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

        if self._axlim is None:
            self._axlim = [0.01, 99, 0.01, 99]
        else:
            # private copy: the lower x limit may be overwritten below, which
            # must not alter the sequence stored in the shared context meta
            self._axlim = list(self._axlim)

        if self._min_dig is not None:
            # lower x limit, in percent
            self._axlim[0] = math.pow(10, self._min_dig) * 100

    def compute(self, idx, input_scores, input_names):
        ''' Plot DET for dev and eval data using
        ``plot.det``

        The dev curve always goes to figure 1. The eval curve goes to
        figure 2 when ``split`` is set, otherwise it is dashed on figure 1.
        For each requested FAR line, the threshold is computed on the dev
        scores and the resulting eval point (ppndf-transformed) is drawn and
        recorded.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        mpl.figure(1)
        # the dev curve is drawn whether or not eval scores are used
        plot.det(
            dev_neg, dev_pos, self._points, color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label('development', dev_file, idx)
        )
        if not (self._eval and eval_neg is not None):
            return

        if self._split:
            mpl.figure(2)
        linestyle = '--' if not self._split else self._linestyles[idx]
        plot.det(
            eval_neg, eval_pos, self._points, color=self._colors[idx],
            linestyle=linestyle,
            label=self._label('eval', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                # threshold a priori: chosen on the dev set, applied to eval
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(
                    eval_neg, eval_pos, thres_line)
                eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

    def _set_axis(self):
        ''' Apply the axis limits with ``plot.det_axis``.'''
        plot.det_axis(self._axlim)
531

Amir MOHAMMADI's avatar
lint    
Amir MOHAMMADI committed
532

533
534
class Epc(PlotBase):
    ''' Draws the EPC curves; needs dev and eval scores for each system '''

    def __init__(self, ctx, scores, evaluation, func_load, hter='HTER'):
        '''
        Parameters
        ----------
        ctx, scores, evaluation, func_load :
            Forwarded to :py:class:`PlotBase`.
        hter : str
            Name of the plotted quantity, used in the default y label.

        Raises
        ------
        click.UsageError
            If the systems are not made of exactly two score files.
        '''
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev and eval score files")

        if not self._titles:
            self._titles = ['EPC'] * 2
        if not self._x_label:
            self._x_label = r'$\alpha$'
        if not self._y_label:
            self._y_label = hter + ' (%)'
        if not self._legend_loc:
            self._legend_loc = 'upper center'

        # always eval data with EPC, on a single figure without FAR lines
        self._eval = True
        self._split = False
        self._nb_figs = 1
        self._far_at = None

    def compute(self, idx, input_scores, input_names):
        ''' Plot EPC using ``plot.epc``; the curve is labelled with both
        file names unless a legend is given for this system.'''
        negatives, positives, _ = utils.get_fta_list(input_scores)
        dev_neg = negatives[0]
        dev_pos = positives[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg = negatives[1]
            eval_pos = positives[1]
            eval_file = input_names[1]

        plot.epc(
            dev_neg, dev_pos, eval_neg, eval_pos, self._points,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label('curve', dev_file + "_" + eval_file, idx)
        )

Amir MOHAMMADI's avatar
lint    
Amir MOHAMMADI committed
566

567
class Hist(PlotBase):
568
    ''' Functional base class for histograms'''
Amir MOHAMMADI's avatar
lint    
Amir MOHAMMADI committed
569

570
    def __init__(self, ctx, scores, evaluation, func_load, nhist_per_system=2):
        '''
        Parameters
        ----------
        ctx, scores, evaluation, func_load :
            Forwarded unchanged to the parent class constructor.
        nhist_per_system : :py:class:`int`
            Number of histograms drawn in one subplot; the ``n_bins`` option
            must hold either one value or this many values.

        The following keys are read from ``ctx.meta``: ``n_bins``, ``thres``,
        ``criterion``, ``no_line``, ``n_row``, ``n_col``, ``hide_dev`` and
        ``legends_ncol``.

        Raises
        ------
        click.BadParameter
            If ``hide_dev`` is requested without evaluation data, or if
            ``n_bins`` / ``thres`` do not have a compatible number of values
            (raised by ``check_list_value``).
        '''
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        self._nhist_per_system = nhist_per_system
        self._nbins = check_list_value(
            ctx.meta.get('n_bins', ['doane']), nhist_per_system, 'n_bins',
            'histograms')
        self._thres = check_list_value(
            ctx.meta.get('thres'), self.n_systems, 'thresholds')
        self._criterion = ctx.meta.get('criterion')
        # no vertical (threshold) is displayed
        self._no_line = ctx.meta.get('no_line', False)
        # subplot grid
        self._nrows = ctx.meta.get('n_row', 1)
        self._ncols = ctx.meta.get('n_col', 1)
        # do not display dev histo
        self._hide_dev = ctx.meta.get('hide_dev', False)
        if self._hide_dev and not self._eval:
            raise click.BadParameter("You can only use --hide-dev along with --eval")

        # True when every system produces two subplots (dev next to eval)
        dev_and_eval = self._eval and not self._hide_dev
        # dev hist are displayed next to eval hist
        if dev_and_eval:
            self._ncols *= 2
        self._nlegends = ctx.meta.get('legends_ncol', 3)
        self._legend_loc = self._legend_loc or 'upper center'
        # number of subplot on one page
        self._step_print = int(self._nrows * self._ncols)
        self._title_base = 'Scores'
        # keep labels that were already set, as done for the other figures
        self._y_label = self._y_label or 'Probability density'
        self._x_label = self._x_label or 'Score values'
        self._end_setup_plot = False
        # Legends are looked up by subplot index. Only when each system owns
        # two subplots (dev and eval) must every legend be repeated;
        # repeating them with a single subplot per system would shift the
        # legend of every system after the first one.
        if dev_and_eval and self._legends is not None \
           and len(self._legends) == self.n_systems:
            # use same legend for dev and eval if needed
            self._legends = [x for pair in zip(self._legends, self._legends)
                             for x in pair]
606

607
    def compute(self, idx, input_scores, input_names):
        ''' Draw histograms of negative and positive scores.

        The development subplot is drawn unless it is hidden, then the
        evaluation subplot when evaluation data are used. When both are
        drawn, system ``idx`` uses the subplot indices ``2 * idx`` and
        ``2 * idx + 1``.
        '''
        dev_neg, dev_pos, eval_neg, eval_pos, threshold = \
            self._get_neg_pos_thres(idx, input_scores, input_names)

        show_line = not self._no_line
        # two subplots per system: dev first, eval right after
        side_by_side = self._eval and not self._hide_dev
        if side_by_side:
            idx *= 2

        if not (self._hide_dev and self._eval):
            self._print_subplot(idx, dev_neg, dev_pos, threshold,
                                show_line, False)

        if side_by_side:
            idx += 1
        if self._eval:
            self._print_subplot(idx, eval_neg, eval_pos, threshold,
                                show_line, True)
621

622
    def _print_subplot(self, idx, neg, pos, threshold, draw_line, evaluation):
        ''' print a subplot for the given score and subplot index

        Parameters
        ----------
        idx : :py:class:`int`
            Global index of the subplot (over all pages).
        neg, pos :
            Negative and positive scores, handed to ``_setup_hist``.
        threshold :
            Position of the vertical line, handed to ``_lines``.
        draw_line : :py:class:`bool`
            Whether the threshold line is drawn.
        evaluation : :py:class:`bool`
            Selects the default title ("Eval scores" or "Dev scores").

        When the subplot is the last one of the page, or the last one
        overall, the legend is added, the figure is saved to
        ``self._pdf_page`` and a new figure is started.
        '''
        # position of the subplot on the current page
        n = idx % self._step_print
        col = n % self._ncols
        sub_plot_idx = n + 1
        axis = mpl.subplot(self._nrows, self._ncols, sub_plot_idx)
        self._setup_hist(neg, pos)
        if col == 0:
            axis.set_ylabel(self._y_label)

        # ``idx`` counts subplots, and each system owns two of them when dev
        # and eval are both displayed
        mult = 2 if self._eval and not self._hide_dev else 1
        n_subplots = self.n_systems * mult
        # rest to be printed, counted from the first subplot of this page.
        # It has to be expressed in subplots (not in systems), otherwise the
        # x label is also put on rows that are not the last one of the page.
        rest_print = n_subplots - \
            (idx // self._step_print) * self._step_print
        # only subplots without another subplot below them get an x label
        if n + self._ncols >= min(self._step_print, rest_print):
            axis.set_xlabel(self._x_label)
        dflt_title = "Eval scores" if evaluation else "Dev scores"
        axis.set_title(self._get_title(idx, dflt_title))
        label = "%s threshold%s" % (
            '' if self._criterion is None else
            self._criterion.upper(), ' (dev)' if self._eval else ''
        )
        if draw_line:
            self._lines(threshold, label, neg, pos, idx)

        # if it was the last subplot of the page or the last subplot
        # to display, save figure
        if self._step_print == sub_plot_idx or idx == n_subplots - 1:
            # print legend on the page
            self.plot_legends()
            mpl.tight_layout()
            self._pdf_page.savefig(mpl.gcf(), bbox_inches='tight')
            mpl.clf()
            mpl.figure()
655

656
    def _get_title(self, idx, dflt=None):
657
        ''' Get the histo title for the given idx'''
658
        title = self._legends[idx] if self._legends is not None \
659
            and idx < len(self._legends) else dflt
660
        title = title or self._title_base
661
662
        title = '' if title is not None and not title.replace(
            ' ', '') else title
663
        return title or ''
664

665
666
    def plot_legends(self):
        ''' Print legend on current page

        Handles and labels are gathered from every axis of the current
        figure, keeping the first handle found for each label. The figure
        legend is drawn only if ``self._disp_legend`` is set.
        '''
        handles, names = [], []
        for axes in mpl.gcf().get_axes():
            ax_handles, ax_names = axes.get_legend_handles_labels()
            for handle, name in zip(ax_handles, ax_names):
                # avoid duplicates in legend
                if name in names:
                    continue
                handles.append(handle)
                names.append(name)

        if not self._disp_legend:
            return
        mpl.gcf().legend(
            handles, names, loc=self._legend_loc, fancybox=True,
            framealpha=0.5, ncol=self._nlegends,
            bbox_to_anchor=(0.55, 1.1),
        )
683

684
    def _get_neg_pos_thres(self, idx, input_scores, input_names):
        ''' Get scores and threshold for the given system at index idx

        Returns the tuple ``(dev_neg, dev_pos, eval_neg, eval_pos,
        threshold)``. The score entries are lists (one item per input file
        of the system); the two eval entries are ``None`` when no
        evaluation data are used.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        n_sets = len(neg_list)
        # lists returned by get_fta_list contains all the following items:
        # for bio or measure without eval:
        #   [dev]
        # for bio or measure with eval (presumably; the original comment
        # labelled this case "vuln" as well):
        #   [dev, eval]
        # for vuln with {licit,spoof} without eval:
        #   [licit_dev, spoof_dev]
        # for vuln with {licit,spoof} with eval:
        #   [licit_dev, licit_eval, spoof_dev, spoof_eval]
        stride = 2 if self._eval else 1
        # can have several files for one system
        dev_ids = range(0, n_sets, stride)
        dev_neg = [neg_list[i] for i in dev_ids]
        dev_pos = [pos_list[i] for i in dev_ids]
        if self._eval:
            eval_ids = range(1, n_sets, stride)
            eval_neg = [neg_list[i] for i in eval_ids]
            eval_pos = [pos_list[i] for i in eval_ids]
        else:
            eval_neg = eval_pos = None

        # a threshold given by the user wins over the one computed on the
        # first development set
        if self._thres is None:
            threshold = utils.get_thres(
                self._criterion, dev_neg[0], dev_pos[0]
            )
        else:
            threshold = self._thres[idx]
        return dev_neg, dev_pos, eval_neg, eval_pos, threshold
710

711
    def _density_hist(self, scores, n, **kwargs):
        ''' Plots one density histo

        ``n`` selects the binning in ``self._nbins``; the remaining keyword
        arguments are given to ``mpl.hist``, whose three results are
        returned as a tuple.
        '''
        nbins = self._nbins[n]
        heights, edges, patches = mpl.hist(
            scores, density=True, bins=nbins, **kwargs
        )
        return (heights, edges, patches)

720
721
    def _lines(self, threshold, label=None, neg=None, pos=None,
               idx=None, **kwargs):
        ''' Plots vertical line at threshold

        ``neg``, ``pos`` and ``idx`` are not used here. Color, linestyle and
        label given in ``kwargs`` take precedence over the defaults.
        '''
        fallbacks = (
            ('color', 'C3'),
            ('linestyle', '--'),
            ('label', label or 'Threshold'),
        )
        for key, value in fallbacks:
            kwargs.setdefault(key, value)
        # plot a vertical threshold line
        mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs)
729
730

    def _setup_hist(self, neg, pos):
        ''' This function can be overwritten in derived classes

        Plots all the density histo required in one plot. Here negative and
        positive scores densities. Only the first item of ``neg`` and of
        ``pos`` is drawn.
        '''
        series = (
            (neg, 'Negatives', 'C3'),
            (pos, 'Positives', 'C0'),
        )
        # the position in ``series`` is also the index of the binning to use
        for slot, (scores, name, colour) in enumerate(series):
            self._density_hist(
                scores[0], n=slot,
                label=name, alpha=0.5, color=colour
            )