figure.py 29 KB
Newer Older
1
2
3
4
'''Runs error analysis on score sets, outputs metrics and plots'''

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
5
import math
6
import sys
7
import os.path
8
9
10
11
12
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
13
from .. import (far_threshold, plot, utils, ppndf)
14

Amir MOHAMMADI's avatar
lint    
Amir MOHAMMADI committed
15

16
17
18
19
20
21
22
23
24
25
26
27
def check_list_value(values, desired_number, name, name2='systems'):
    """Make sure an option holds exactly ``desired_number`` values.

    Parameters
    ----------
    values : sequence or None
        Values given for the option. ``None`` is returned untouched.
    desired_number : int
        Number of values that are expected.
    name : str
        Name of the option, only used in the error message.
    name2 : str
        Name of what the values are matched against, only used in the error
        message.

    Returns
    -------
    sequence or None
        ``values`` as given when it is ``None`` or already has the right
        length, or ``values`` repeated ``desired_number`` times when it holds
        a single element.

    Raises
    ------
    click.BadParameter
        If ``values`` holds neither one nor ``desired_number`` elements.
    """
    # Nothing to adapt: option not given or already of the right size.
    if values is None or len(values) == desired_number:
        return values

    if len(values) != 1:
        raise click.BadParameter(
            '#{} ({}) must be either 1 value or the same as '
            '#{} ({} values)'.format(name, values, name2, desired_number))

    # A single value is shared by all the entries.
    return values * desired_number


28
29
30
31
32
33
34
35
36
37
class MeasureBase(object):
    """Base class for metrics and plots.

    This class defines the framework used to plot or compute metrics from a
    list of score files: :py:meth:`run` calls :py:meth:`init_process` once,
    :py:meth:`compute` once per system and finally :py:meth:`end_process`.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    n_systems : int
        Number of systems, i.e. the number of input files divided by the
        number of files expected for each system (``min_arg``).
    """
    # Python 2 style metaclass declaration, kept for python 2.7
    # compatibility. NOTE: Python 3 ignores this attribute, so the
    # ``abstractmethod`` markers below are not enforced there.
    __metaclass__ = ABCMeta

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx : :py:class:`dict`
            Click context dictionary.

        scores : :any:`list`:
            List of input files (e.g. dev-{1, 2, 3}, {dev,eval}-scores1
            {dev,eval}-scores2)
        evaluation : :py:class:`bool`
            True if eval data are used
        func_load : Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If the number of input files is zero or not a multiple of
            ``min_arg``, or if fewer legends than systems are given.
        """
        meta = ctx.meta
        self._scores = scores
        self._ctx = ctx
        self.func_load = func_load
        self._legends = meta.get('legends')
        self._eval = evaluation
        # number of input files that make up one system
        self._min_arg = meta.get('min_arg', 1)

        n_files = len(scores)
        if n_files < 1 or n_files % self._min_arg != 0:
            raise click.BadParameter(
                'Number of argument must be a non-zero multiple of %d' % self._min_arg
            )
        self.n_systems = int(n_files // self._min_arg)

        too_few_legends = (self._legends is not None
                           and len(self._legends) < self.n_systems)
        if too_few_legends:
            raise click.BadParameter("Number of legends must be >= to the "
                                     "number of systems")

    def run(self):
        """ Generate outputs (e.g. metrics, files, pdf plots).
        This function calls
        :func:`~bob.measure.script.figure.MeasureBase.init_process` (before
        loop), :py:func:`~bob.measure.script.figure.MeasureBase.compute`
        (in the loop iterating through the different
        systems) and :py:func:`~bob.measure.script.figure.MeasureBase.end_process`
        (after the loop).
        """
        # init matplotlib, log files, ...
        self.init_process()

        # Scores are given as followed:
        # SysA-dev SysA-eval ... SysA-XX  SysB-dev SysB-eval ... SysB-XX
        # ------------------------------  ------------------------------
        #   First set of `self._min_arg`     Second set of input files
        #     input files starting at               for SysB
        #    index idx * self._min_arg
        # Note that more than one dev or eval score file can be passed to
        # each system.
        for idx in range(self.n_systems):
            per_system = self._min_arg
            first = idx * per_system
            # load scores for each system: get the corresponding arrays and
            # base-name of files
            input_scores, input_names = self._load_files(
                self._scores[first:first + per_system]
            )
            self.compute(idx, input_scores, input_names)

        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        before iterating through the different systems.
        Does nothing by default; can be reimplemented in derived classes"""
        pass

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots from the given scores provided by
        :py:func:`~bob.measure.script.figure.MeasureBase.run`.
        Should be reimplemented in derived classes

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        """
        pass

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        after iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # common protected functions

    def _load_files(self, filepaths):
        ''' Load the input files and return the base names of the files

        Parameters
        ----------
            filepaths: :any:`list`:
                Paths of the files to load

        Returns
        -------
            scores: :any:`list`:
                A list that contains the output of
                ``func_load`` for the given files
            basenames: :any:`list`:
                A list of basenames for the given files (file name up to its
                first ``.``)
        '''
        scores, basenames = [], []
        for path in filepaths:
            # keep what precedes the first dot of the file name
            basenames.append(os.path.basename(path).partition(".")[0])
            scores.append(self.func_load(path))
        return scores, basenames
152

Amir MOHAMMADI's avatar
lint    
Amir MOHAMMADI committed
153

154
155
156
157
158
159
160
161
class Metrics(MeasureBase):
    ''' Compute metrics from score files

    For every system a threshold is obtained (computed on the development
    scores, or taken from the user provided ``thres`` values) and a table of
    error rates is printed for the development set and, when evaluation
    scores are used, for the evaluation set.

    Attributes
    ----------
    log_file: file object
        output stream (``sys.stdout`` unless a ``log`` path is given in the
        click context)
    names: sequence of str
        Row labels of the table, in the order FtA, FMR, FNMR, FAR, FRR, HTER.
    '''

    def __init__(self, ctx, scores, evaluation, func_load,
                 names=('FtA', 'FMR', 'FNMR', 'FAR', 'FRR', 'HTER')):
        '''
        Parameters
        ----------
        ctx, scores, evaluation, func_load
            Forwarded to :py:class:`MeasureBase`.
        names : sequence of str
            Six row labels used in the printed table.

        Raises
        ------
        click.BadParameter
            If thresholds are provided and their number is neither 1 nor the
            number of systems.
        '''
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        self.names = names
        self._tablefmt = ctx.meta.get('tablefmt')
        self._criterion = ctx.meta.get('criterion')
        self._open_mode = ctx.meta.get('open_mode')
        self._thres = ctx.meta.get('thres')
        if self._thres is not None:
            if len(self._thres) == 1:
                # a single threshold is shared by all the systems
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                # n_systems is an int: format it directly (calling len() on
                # it raised a TypeError that hid this error message)
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._far = ctx.meta.get('far_value')
        self._log = ctx.meta.get('log')
        self.log_file = sys.stdout
        if self._log is not None:
            # open() rejects ``None`` as a mode: fall back to write mode
            # when no ``open_mode`` is available in the context
            self.log_file = open(self._log, self._open_mode or 'w')

    def get_thres(self, criterion, dev_neg, dev_pos, far):
        ''' Return the threshold given by ``utils.get_thres`` for the given
        criterion, development scores and FAR value'''
        return utils.get_thres(criterion, dev_neg, dev_pos, far)

    @staticmethod
    def _format_rates(neg, pos, fta, threshold):
        ''' Return the six table cells (FtA, FMR, FNMR, FAR, FRR, HTER) as
        strings for one set of scores at the given threshold'''
        from .. import farfrr
        fmr, fnmr = farfrr(neg, pos, threshold)
        # FAR/FRR are FMR/FNMR corrected by the failure-to-acquire rate
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)
        hter = (far + frr) / 2.0

        n_impostors = neg.shape[0]
        n_false_accepts = int(round(fmr * n_impostors))
        n_clients = pos.shape[0]
        n_false_rejects = int(round(fnmr * n_clients))

        return [
            "%.1f%%" % (100 * fta),
            "%.1f%% (%d/%d)" % (100 * fmr, n_false_accepts, n_impostors),
            "%.1f%% (%d/%d)" % (100 * fnmr, n_false_rejects, n_clients),
            "%.1f%%" % (100 * far),
            "%.1f%%" % (100 * frr),
            "%.1f%%" % (100 * hter),
        ]

    def compute(self, idx, input_scores, input_names):
        ''' Compute metrics thresholds and tables (FAR, FMR, FNMR, HTER) for
        given system inputs

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            loaded scores of the system: development first, then evaluation
            when evaluation scores are used
        input_names: :any:`list`
            base names of the input files, in the same order
        '''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]
            eval_file = input_names[1]

        title = self._legends[idx] if self._legends is not None else None

        if self._thres is None:
            threshold = self.get_thres(
                self._criterion, dev_neg, dev_pos, self._far)
            far_str = ''
            if self._criterion == 'far' and self._far is not None:
                far_str = str(self._far)
            click.echo("[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                       % (self._criterion.upper(),
                          far_str, title or dev_file,
                          threshold),
                       file=self.log_file)
        else:
            threshold = self._thres[idx]
            # NOTE(review): this branch prefers the file name over the legend
            # while the branch above prefers the legend; kept as is.
            click.echo("[Min. criterion: user provided] Threshold on "
                       "Development set `%s`: %e"
                       % (dev_file or title, threshold), file=self.log_file)

        dev_cells = self._format_rates(dev_neg, dev_pos, dev_fta, threshold)
        # an absent legend must give an empty header cell, not ``None``
        headers = [title or '', 'Development %s' % dev_file]
        raws = [[self.names[i], cell] for i, cell in enumerate(dev_cells)]

        if self._eval:
            # computes statistics for the eval set based on the threshold a
            # priori
            eval_cells = self._format_rates(
                eval_neg, eval_pos, eval_fta, threshold)
            headers.append('Eval. %s' % eval_file)
            for raw, cell in zip(raws, eval_cells):
                raw.append(cell)

        click.echo(tabulate(raws, headers, self._tablefmt), file=self.log_file)

    def end_process(self):
        ''' Close log file if needed'''
        if self._log is not None:
            self.log_file.close()

Amir MOHAMMADI's avatar
lint    
Amir MOHAMMADI committed
278

279
280
281
282
class PlotBase(MeasureBase):
    ''' Base class for plots. Regroup several options and code
    shared by the different plots

    All the plotting options are read from the ``meta`` dictionary of the
    click context. Figures are saved in a
    :py:class:`matplotlib.backends.backend_pdf.PdfPages` that is either taken
    from ``ctx.meta['PdfPages']`` or created from the ``output`` path.
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        '''
        Parameters
        ----------
        ctx, scores, evaluation, func_load
            Forwarded to :py:class:`MeasureBase`.
        '''
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        self._output = ctx.meta.get('output')
        self._points = ctx.meta.get('points', 100)
        self._split = ctx.meta.get('split')
        self._axlim = ctx.meta.get('axlim')
        self._disp_legend = ctx.meta.get('disp_legend', True)
        self._legend_loc = ctx.meta.get('legend_loc')

        # power of ten of the smallest FAR value, taken from
        # ``min_far_value`` or, failing that, from the lower x-axis limit
        self._min_dig = None
        if 'min_far_value' in ctx.meta:
            self._min_dig = int(math.log10(ctx.meta['min_far_value']))
        elif self._axlim is not None and self._axlim[0] is not None:
            self._min_dig = int(math.log10(self._axlim[0])
                                if self._axlim[0] != 0 else 0)

        self._clayout = ctx.meta.get('clayout')
        self._far_at = ctx.meta.get('lines_at')
        # x positions of the vertical lines; derived classes may replace it
        # with values expressed in their own axis scale
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            self._eval_points = {line: [] for line in self._far_at}
            self._lines_val = []
        self._print_fn = ctx.meta.get('show_fn', True)
        self._x_rotation = ctx.meta.get('x_rotation')
        if 'style' in ctx.meta:
            mpl.style.use(ctx.meta['style'])
        # dev and eval are drawn on separate figures only when split
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._line_linestyles = ctx.meta.get('line_linestyles', False)
        self._linestyles = utils.get_linestyles(
            self.n_systems, self._line_linestyles)
        # one title per figure; ``or []`` also covers a key that is present
        # but set to ``None``
        self._titles = (ctx.meta.get('titles') or []) * 2
        # for compatibility
        self._title = ctx.meta.get('title')
        if not self._titles and self._title is not None:
            self._titles = [self._title] * 2

        self._x_label = ctx.meta.get('x_label')
        self._y_label = ctx.meta.get('y_label')
        self._grid_color = 'silver'
        self._pdf_page = None
        # when False, end_process does not decorate nor save the figures
        self._end_setup_plot = True

    def init_process(self):
        ''' Open pdf and create (cleared) figures '''
        if not hasattr(matplotlib, 'backends'):
            matplotlib.use('pdf')

        if 'PdfPages' in self._ctx.meta:
            self._pdf_page = self._ctx.meta['PdfPages']
        else:
            self._pdf_page = PdfPages(self._output)

        fs = self._ctx.meta.get('figsize')
        for i in range(self._nb_figs):
            fig = mpl.figure(i + 1, figsize=fs)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        ''' Set title, legend, axis labels, grid colors, save figures, draw
        lines and close pdf if needed '''
        # draw vertical lines
        if self._far_at is not None:
            for (line, line_trans) in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans], [-100.0, 100.], "--",
                    color='black'
                )
                if self._eval and self._split:
                    # on the eval figure, join the points recorded for this
                    # line, ordered by their x value
                    mpl.figure(2)
                    points = sorted(
                        self._eval_points[line], key=lambda point: point[0]
                    )
                    mpl.plot([x for x, _ in points],
                             [y for _, y in points], '--',
                             color='black')
        # only for plots
        if self._end_setup_plot:
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = '' if not self._titles else self._titles[i]
                # a title made of spaces only means "no title"
                mpl.title(title if title.replace(' ', '') else '')
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                if self._disp_legend:
                    mpl.legend(loc=self._legend_loc)
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        # A PdfPages shared through the context is left open when ``closef``
        # is False (do not want to close PDF when running evaluate). A
        # PdfPages created by init_process is not reachable by anybody else,
        # so it is always closed here; otherwise the file is never finalised.
        shared = 'PdfPages' in self._ctx.meta
        may_close = not shared or self._ctx.meta.get('closef', True)
        if self._pdf_page is not None and may_close:
            self._pdf_page.close()

    # common protected functions

    def _label(self, base, name, idx):
        ''' Return the legend label of system ``idx``: the user legend when
        there is one, else a label built from ``base`` and ``name`` '''
        if self._legends is not None and len(self._legends) > idx:
            return self._legends[idx]
        if self.n_systems > 1:
            return base + (" %d (%s)" % (idx + 1, name))
        return base + (" (%s)" % name)

    def _set_axis(self):
        ''' Apply the axis limits on the current figure, if any '''
        if self._axlim is not None:
            mpl.axis(self._axlim)
394

Amir MOHAMMADI's avatar
lint    
Amir MOHAMMADI committed
395

396
class Roc(PlotBase):
    ''' Handles the plotting of ROC curves

    The development curve is always drawn on figure 1. With evaluation
    scores, the evaluation curve is drawn on figure 2 when ``split`` is set,
    or dashed on figure 1 otherwise.
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        # defaults used when the context did not provide a value
        if not self._titles:
            self._titles = ['ROC dev', 'ROC eval']
        if not self._x_label:
            self._x_label = 'False Positive Rate'
        if not self._y_label:
            self._y_label = "1 - False Negative Rate"
        self._semilogx = ctx.meta.get('semilogx', True)
        if not self._legend_loc:
            self._legend_loc = \
                'lower right' if self._semilogx else 'upper right'
        # custom defaults
        if self._axlim is None:
            self._axlim = [None, None, -0.05, 1.05]

    def compute(self, idx, input_scores, input_names):
        ''' Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc_for_far`'''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        # the development curve is drawn the same way with or without
        # evaluation data
        mpl.figure(1)
        plot.roc_for_far(
            dev_neg, dev_pos,
            far_values=plot.log_values(self._min_dig or -4),
            CAR=self._semilogx,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label('dev', dev_file, idx)
        )
        if not self._eval:
            return

        if self._split:
            mpl.figure(2)
            eval_style = self._linestyles[idx]
        else:
            # same figure as the dev curve: tell them apart with dashes
            eval_style = '--'
        plot.roc_for_far(
            eval_neg, eval_pos, linestyle=eval_style,
            far_values=plot.log_values(self._min_dig or -4),
            CAR=self._semilogx,
            color=self._colors[idx],
            label=self._label('eval', eval_file, idx)
        )

        if self._far_at is None:
            return
        from .. import farfrr
        for line in self._far_at:
            # threshold chosen on dev, operating point measured on eval
            thres_line = far_threshold(dev_neg, dev_pos, line)
            eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, thres_line)
            point = (eval_fmr, 1 - eval_fnmr)
            mpl.scatter(point[0], point[1], c=self._colors[idx], s=30)
            self._eval_points[line].append(point)

Amir MOHAMMADI's avatar
lint    
Amir MOHAMMADI committed
459

460
461
class Det(PlotBase):
    ''' Handles the plotting of DET curves

    The development curve is always drawn on figure 1. With evaluation
    scores, the evaluation curve is drawn on figure 2 when ``split`` is set,
    or dashed on figure 1 otherwise.
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ['DET dev', 'DET eval']
        self._x_label = self._x_label or 'False Positive Rate (%)'
        self._y_label = self._y_label or 'False Negative Rate (%)'
        self._legend_loc = self._legend_loc or 'upper right'
        if self._far_at is not None:
            # vertical lines are positioned with ppndf-transformed values
            self._trans_far_val = [ppndf(float(k)) for k in self._far_at]
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

        if self._axlim is None:
            self._axlim = [0.01, 99, 0.01, 99]
        else:
            # work on a private copy: the limits come from the click context
            # and are modified below, which must neither leak to other
            # users of the context nor fail when they are given as a tuple
            self._axlim = list(self._axlim)

        if self._min_dig is not None:
            self._axlim[0] = math.pow(10, self._min_dig) * 100

    def compute(self, idx, input_scores, input_names):
        ''' Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`'''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        # the development curve is drawn the same way with or without
        # evaluation data
        mpl.figure(1)
        plot.det(
            dev_neg, dev_pos, self._points, color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label('development', dev_file, idx)
        )
        if not (self._eval and eval_neg is not None):
            return

        if self._split:
            mpl.figure(2)
        # same figure as the dev curve: tell them apart with dashes
        linestyle = '--' if not self._split else self._linestyles[idx]
        plot.det(
            eval_neg, eval_pos, self._points, color=self._colors[idx],
            linestyle=linestyle,
            label=self._label('eval', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                # threshold chosen on dev, operating point measured on eval
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(
                    eval_neg, eval_pos, thres_line)
                eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

    def _set_axis(self):
        ''' Apply the axis limits with :py:func:`bob.measure.plot.det_axis` '''
        plot.det_axis(self._axlim)
524

Amir MOHAMMADI's avatar
lint    
Amir MOHAMMADI committed
525

526
527
class Epc(PlotBase):
    ''' Handles the plotting of EPC

    One curve per system is drawn on a single figure, from both the
    development and the evaluation scores.
    '''

    def __init__(self, ctx, scores, evaluation, func_load, hter='HTER'):
        '''
        Parameters
        ----------
        ctx, scores, evaluation, func_load
            Forwarded to :py:class:`PlotBase`.
        hter : str
            Name of the plotted rate, used in the default y-axis label.

        Raises
        ------
        click.UsageError
            If the number of files per system (``min_arg``) is not 2.
        '''
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev and eval score files")

        # defaults used when the context did not provide a value
        if not self._titles:
            self._titles = ['EPC'] * 2
        if not self._x_label:
            self._x_label = r'$\alpha$'
        if not self._y_label:
            self._y_label = hter + ' (%)'
        if not self._legend_loc:
            self._legend_loc = 'upper center'

        # always eval data with EPC, on one single figure without
        # vertical lines
        self._eval = True
        self._split = False
        self._nb_figs = 1
        self._far_at = None

    def compute(self, idx, input_scores, input_names):
        ''' Plot EPC using :py:func:`bob.measure.plot.epc` '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        # the curve is named after both of its input files
        curve_label = self._label('curve', dev_file + "_" + eval_file, idx)
        plot.epc(
            dev_neg, dev_pos, eval_neg, eval_pos, self._points,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=curve_label
        )

Amir MOHAMMADI's avatar
lint    
Amir MOHAMMADI committed
559

560
class Hist(PlotBase):
561
    ''' Functional base class for histograms'''
Amir MOHAMMADI's avatar
lint    
Amir MOHAMMADI committed
562

563
    def __init__(self, ctx, scores, evaluation, func_load, nhist_per_system=2):
        '''
        Parameters
        ----------
        ctx, scores, evaluation, func_load
            Forwarded to :py:class:`PlotBase`.
        nhist_per_system : int
            Number of histograms drawn in one subplot; ``n_bins`` must hold
            one value or that many values.
        '''
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._nhist_per_system = nhist_per_system
        self._nbins = check_list_value(
            meta.get('n_bins', ['doane']), nhist_per_system, 'n_bins',
            'histograms')
        self._thres = check_list_value(
            meta.get('thres'), self.n_systems, 'thresholds')
        self._criterion = meta.get('criterion')
        # no vertical (threshold) is displayed
        self._no_line = meta.get('no_line', False)
        # do not display dev histo
        self._hide_dev = meta.get('hide_dev', False)
        # subplot grid; dev hist are displayed next to eval hist, which
        # doubles the number of columns
        self._nrows = meta.get('n_row', 1)
        n_col = meta.get('n_col', 1)
        self._ncols = n_col if self._hide_dev else n_col * 2
        self._nlegends = meta.get('legends_ncol', 3)
        if not self._legend_loc:
            self._legend_loc = 'upper center'
        # number of subplot on one page
        self._step_print = int(self._nrows * self._ncols)
        self._title_base = 'Scores'
        self._y_label = 'Probability density'
        self._x_label = 'Scores values'
        # pages are saved by _print_subplot, not by end_process
        self._end_setup_plot = False

        one_legend_per_system = (self._legends is not None
                                 and len(self._legends) == self.n_systems)
        if one_legend_per_system and not self._hide_dev:
            # use same legend for dev and eval if needed
            doubled = []
            for legend in self._legends:
                doubled.extend((legend, legend))
            self._legends = doubled
596

597
    def compute(self, idx, input_scores, input_names):
598
        ''' Draw histograms of negative and positive scores.'''
599
        dev_neg, dev_pos, eval_neg, eval_pos, threshold = \
Amir MOHAMMADI's avatar
lint    
Amir MOHAMMADI committed
600
            self._get_neg_pos_thres(idx, input_scores, input_names)
601
602
603
604
605
606
607
608
609
610
611
612
        idx *= 1 if self._hide_dev or not self._eval else 2

        if not self._hide_dev or not self._eval:
            self._print_subplot(idx, dev_neg, dev_pos, threshold, False,
                                dflt_title="Dev scores")

        idx += 1 if self._eval and not self._hide_dev else 0
        if self._eval:
            self._print_subplot(idx, eval_neg, eval_pos, threshold,
                                not self._no_line, dflt_title="Eval scores")

    def _print_subplot(self, idx, neg, pos, threshold, draw_line, dflt_title):
        ''' print a subplot for the given score and subplot index

        Parameters
        ----------
        idx : int
            index of the subplot, counted over all the pages
        neg, pos
            negative and positive scores forwarded to ``_setup_hist``
        threshold
            position of the vertical line
        draw_line : bool
            whether the threshold line is drawn
        dflt_title : str
            title used when no legend is available for this subplot
        '''
        per_page = self._step_print
        # position of the subplot on the current page
        n = idx % per_page
        sub_plot_idx = n + 1
        axis = mpl.subplot(self._nrows, self._ncols, sub_plot_idx)
        self._setup_hist(neg, pos)

        # y label on the first column only
        if n % self._ncols == 0:
            axis.set_ylabel(self._y_label)

        # rest to be printed
        rest_print = self.n_systems - int(idx // per_page) * per_page
        if n + self._ncols >= min(per_page, rest_print):
            axis.set_xlabel(self._x_label)
        axis.set_title(self._get_title(idx, dflt_title))

        criterion = '' if self._criterion is None else self._criterion.upper()
        suffix = ' (dev)' if self._eval else ''
        label = "%s threshold%s" % (criterion, suffix)
        if draw_line:
            self._lines(threshold, label, neg, pos, idx)

        mult = 2 if self._eval and not self._hide_dev else 1
        page_is_full = self._step_print == sub_plot_idx
        is_last_subplot = idx == self.n_systems * mult - 1
        # if it was the last subplot of the page or the last subplot
        # to display, save figure
        if page_is_full or is_last_subplot:
            # print legend on the page
            self.plot_legends()
            mpl.tight_layout()
            self._pdf_page.savefig(mpl.gcf(), bbox_inches='tight')
            mpl.clf()
            mpl.figure()
644

645
    def _get_title(self, idx, dflt=None):
646
        ''' Get the histo title for the given idx'''
647
        title = self._legends[idx] if self._legends is not None \
648
            and idx < len(self._legends) else dflt
649
        title = title or self._title_base
650
651
        title = '' if title is not None and not title.replace(
            ' ', '') else title
652
        return title or ''
653

654
655
    def plot_legends(self):
        ''' Print legend on current page

        The handles of all the axes of the current figure are gathered, each
        label being kept once, and shown in one figure-level legend unless
        legends are disabled.
        '''
        handles, names = [], []
        for ax in mpl.gcf().get_axes():
            for handle, name in zip(*ax.get_legend_handles_labels()):
                # avoid duplicates in legend
                if name in names:
                    continue
                handles.append(handle)
                names.append(name)

        if not self._disp_legend:
            return
        mpl.gcf().legend(
            handles, names, loc=self._legend_loc, fancybox=True,
            framealpha=0.5, ncol=self._nlegends,
            bbox_to_anchor=(0.55, 1.1),
        )
672

673
    def _get_neg_pos_thres(self, idx, input_scores, input_names):
        ''' Get scores and threshold for the given system at index idx

        Returns
        -------
        tuple
            ``(dev_neg, dev_pos, eval_neg, eval_pos, threshold)`` where the
            four first items are lists (``None`` for the eval ones when no
            eval scores are used).
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        n_sets = len(neg_list)

        # can have several files for one system: dev sets are at the even
        # positions, eval sets at the odd ones
        dev_positions = range(0, n_sets, 2)
        dev_neg = [neg_list[i] for i in dev_positions]
        dev_pos = [pos_list[i] for i in dev_positions]

        if self._eval:
            eval_positions = range(1, n_sets, 2)
            eval_neg = [neg_list[i] for i in eval_positions]
            eval_pos = [pos_list[i] for i in eval_positions]
        else:
            eval_neg = eval_pos = None

        if self._thres is None:
            # computed on the first dev set of the system
            threshold = utils.get_thres(
                self._criterion, dev_neg[0], dev_pos[0]
            )
        else:
            threshold = self._thres[idx]
        return dev_neg, dev_pos, eval_neg, eval_pos, threshold
689

690
    def _density_hist(self, scores, n, **kwargs):
        ''' Plots one density histo

        ``n`` selects the entry of the ``n_bins`` values used for this
        histogram; extra keyword arguments are forwarded to the matplotlib
        ``hist`` call, whose three return values are returned as a tuple.
        '''
        counts, edges, patches = mpl.hist(
            scores, density=True, bins=self._nbins[n], **kwargs
        )
        return (counts, edges, patches)

699
700
    def _lines(self, threshold, label=None, neg=None, pos=None,
               idx=None, **kwargs):
        ''' Plots vertical line at threshold

        ``neg``, ``pos`` and ``idx`` are not used here. Color, line style and
        label given as keyword arguments take precedence over the defaults.
        '''
        defaults = (
            ('color', 'C3'),
            ('linestyle', '--'),
            ('label', label or 'Threshold'),
        )
        for key, value in defaults:
            kwargs.setdefault(key, value)
        # plot a vertical threshold line
        mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs)
708
709

    def _setup_hist(self, neg, pos):
710
711
712
713
714
        ''' This function can be overwritten in derived classes

        Plots all the density histo required in one plot. Here negative and
        positive scores densities.
        '''
715
        self._density_hist(
716
717
            neg[0], n=0,
            label='Negatives', alpha=0.5, color='C3'
718
719
        )
        self._density_hist(
720
721
            pos[0], n=1,
            label='Positives', alpha=0.5, color='C0'
722
        )