figure.py 27.1 KB
Newer Older
1
2
3
4
'''Runs error analysis on score sets, outputs metrics and plots'''

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
5
import math
6
import sys
7
import os.path
8
9
10
11
12
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
13
from .. import (far_threshold, plot, utils, ppndf)
14
15
16
17
18
19
20
21
22
23
24
25

class MeasureBase(object):
    """Base class for metrics and plots.

    This abstract class defines the framework used to plot or compute
    metrics from a list of (positive, negative) score tuples.
    :py:meth:`run` calls :py:meth:`init_process` once, then
    :py:meth:`compute` once per system, then :py:meth:`end_process`.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    n_systems: int
        Number of systems, i.e. ``len(scores) // min_arg``
    """
    # for python 2.7 compatibility; this attribute is ignored on python 3,
    # so @abstractmethod is not enforced there
    __metaclass__ = ABCMeta

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx : click context
            Click context; options are read from its ``meta`` dictionary.

        scores : :any:`list`:
            List of input files (e.g. dev-{1, 2, 3}, {dev,eval}-scores1
            {dev,eval}-scores2)
        evaluation : :py:class:`bool`
            True if eval data are used
        func_load : Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If the number of score files is not a non-zero multiple of
            ``ctx.meta['min_arg']`` (default 1), or if legends are given and
            their number differs from the number of systems.
        """
        meta = ctx.meta
        self._scores = scores
        self._ctx = ctx
        self.func_load = func_load
        self._legends = meta.get('legends')
        self._eval = evaluation
        # number of files that describe one system (e.g. 2 for dev + eval)
        self._min_arg = meta.get('min_arg', 1)
        if len(scores) < 1 or len(scores) % self._min_arg != 0:
            raise click.BadParameter(
                'Number of argument must be a non-zero multiple of %d' % self._min_arg
            )
        self.n_systems = len(scores) // self._min_arg
        if self._legends is not None and len(self._legends) != self.n_systems:
            raise click.BadParameter("Number of legends must be equal to the "
                                     "number of systems")

    def run(self):
        """ Generate outputs (e.g. metrics, files, pdf plots).
        This function calls abstract methods
        :func:`~bob.measure.script.figure.MeasureBase.init_process` (before
        loop), :py:func:`~bob.measure.script.figure.MeasureBase.compute`
        (in the loop iterating through the different
        systems) and :py:func:`~bob.measure.script.figure.MeasureBase.end_process`
        (after the loop).
        """
        # init matplotlib, log files, ...
        self.init_process()
        # Scores are given as followed:
        # SysA-dev SysA-eval ... SysA-XX  SysB-dev SysB-eval ... SysB-XX
        # ------------------------------  ------------------------------
        #   First set of `self._min_arg`     Second set of input files
        #     input files starting at               for SysB
        #    index idx * self._min_arg
        # Note that more than one dev or eval scores score can be passed to
        # each system
        for idx in range(self.n_systems):
            start = idx * self._min_arg
            input_scores, input_names = self._load_files(
                self._scores[start:start + self._min_arg]
            )
            self.compute(idx, input_scores, input_names)
        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        before iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots from the given scores provided by
        :py:func:`~bob.measure.script.figure.MeasureBase.run`.
        Should be reimplemented in derived classes

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        """
        pass

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        after iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # common protected functions
    def _load_files(self, filepaths):
        ''' Load the input files and return the base names of the files

        Returns
        -------
            scores: :any:`list`:
                A list that contains the output of
                ``func_load`` for the given files
            basenames: :any:`list`:
                A list of basenames for the given files, truncated at the
                first ``.`` (e.g. ``/a/sys.dev.txt`` -> ``sys``)
        '''
        scores = []
        basenames = []
        for filename in filepaths:
            basenames.append(os.path.basename(filename).split(".")[0])
            scores.append(self.func_load(filename))
        return scores, basenames

class Metrics(MeasureBase):
    ''' Compute metrics from score files

    Attributes
    ----------
    log_file: file object
        output stream (``sys.stdout`` unless ``ctx.meta['log']`` is set)
    '''

    # row labels of the metrics table, in the order produced by _rates_column
    _ROW_NAMES = ('FtA', 'FMR', 'FNMR', 'FAR', 'FRR', 'HTER')

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._tablefmt = meta.get('tablefmt')
        self._criterion = meta.get('criterion')
        self._open_mode = meta.get('open_mode')
        self._thres = meta.get('thres')
        if self._thres is not None:
            if len(self._thres) == 1:
                # a single threshold is shared by every system
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._far = meta.get('far_value')
        self._log = meta.get('log')
        self.log_file = sys.stdout
        if self._log is not None:
            # open() rejects mode=None, so default to write mode
            self.log_file = open(self._log, self._open_mode or 'w')

    @staticmethod
    def _rates_column(neg, pos, fta, threshold):
        ''' Return the formatted FtA, FMR, FNMR, FAR, FRR and HTER cells for
        one score set evaluated at ``threshold``.

        FAR and FRR account for the failure-to-acquire rate ``fta``:
        ``FAR = FMR * (1 - FtA)`` and ``FRR = FtA + FNMR * (1 - FtA)``.
        '''
        from .. import farfrr
        fmr, fnmr = farfrr(neg, pos, threshold)
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)
        hter = (far + frr) / 2.0

        ni = neg.shape[0]  # number of impostors
        fm = int(round(fmr * ni))  # number of false accepts
        nc = pos.shape[0]  # number of clients
        fnm = int(round(fnmr * nc))  # number of false rejects

        return [
            "%.1f%%" % (100 * fta),
            "%.1f%% (%d/%d)" % (100 * fmr, fm, ni),
            "%.1f%% (%d/%d)" % (100 * fnmr, fnm, nc),
            "%.1f%%" % (100 * far),
            "%.1f%%" % (100 * frr),
            "%.1f%%" % (100 * hter),
        ]

    def compute(self, idx, input_scores, input_names):
        ''' Compute metrics thresholds and tables (FtA, FMR, FNMR, FAR, FRR,
        HTER) for given system inputs.

        The threshold is either user provided (``ctx.meta['thres']``) or
        computed on the development set with ``self._criterion``; the same
        threshold is then applied to the evaluation set.
        '''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]
            eval_file = input_names[1]

        title = self._legends[idx] if self._legends is not None else None
        if self._thres is None:
            threshold = utils.get_thres(self._criterion, dev_neg, dev_pos, self._far)
            far_str = ''
            if self._criterion == 'far' and self._far is not None:
                far_str = str(self._far)
            click.echo("[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                       % (self._criterion.upper(),
                          far_str, title or dev_file,
                          threshold),
                       file=self.log_file)
        else:
            threshold = self._thres[idx]
            click.echo("[Min. criterion: user provider] Threshold on "
                       "Development set `%s`: %e"
                       % (title or dev_file, threshold), file=self.log_file)

        headers = [title or '', 'Development %s' % dev_file]
        columns = [self._rates_column(dev_neg, dev_pos, dev_fta, threshold)]
        if self._eval:
            # statistics for the eval set use the threshold chosen a priori
            headers.append('Eval. %s' % eval_file)
            columns.append(
                self._rates_column(eval_neg, eval_pos, eval_fta, threshold)
            )
        raws = [[name] + [column[row] for column in columns]
                for row, name in enumerate(self._ROW_NAMES)]

        click.echo(tabulate(raws, headers, self._tablefmt), file=self.log_file)

    def end_process(self):
        ''' Close log file if needed'''
        if self._log is not None:
            self.log_file.close()

class PlotBase(MeasureBase):
    ''' Base class for plots. Regroup several options and code
    shared by the different plots
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._output = meta.get('output')
        self._points = meta.get('points', 100)
        self._split = meta.get('split')
        self._axlim = meta.get('axlim')
        # decimal exponent of the smallest FAR to display
        self._min_dig = None
        if 'min_far_value' in meta:
            self._min_dig = int(math.log10(meta['min_far_value']))
        elif self._axlim is not None:
            self._min_dig = int(math.log10(self._axlim[0])
                                if self._axlim[0] != 0 else 0)
        self._clayout = meta.get('clayout')
        self._far_at = meta.get('lines_at')
        # x positions of the vertical FAR lines (subclasses may transform them)
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            self._eval_points = {line: [] for line in self._far_at}
            self._lines_val = []
        self._print_fn = meta.get('show_fn', True)
        self._x_rotation = meta.get('x_rotation')
        if 'style' in meta:
            mpl.style.use(meta['style'])
        # a second figure is used only for split dev/eval plots
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._line_linestyles = meta.get('line_linestyles', False)
        self._linestyles = utils.get_linestyles(self.n_systems, self._line_linestyles)
        self._states = ['Development', 'Evaluation']
        self._title = meta.get('title')
        self._x_label = meta.get('x_label')
        self._y_label = meta.get('y_label')
        self._grid_color = 'silver'
        self._pdf_page = None
        self._end_setup_plot = True

    def init_process(self):
        ''' Open pdf and set axis font size if provided '''
        if not hasattr(matplotlib, 'backends'):
            matplotlib.use('pdf')

        meta = self._ctx.meta
        self._pdf_page = meta['PdfPages'] if 'PdfPages' in meta \
            else PdfPages(self._output)

        figsize = meta.get('figsize')
        for i in range(self._nb_figs):
            fig = mpl.figure(i + 1, figsize=figsize)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        ''' Set title, legend, axis labels, grid colors, save figures, draw
        lines and close pdf if needed '''
        # draw vertical lines
        if self._far_at is not None:
            for line, line_trans in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans], [-100.0, 100.], "--",
                    color='black'
                )
                if self._eval and self._split:
                    mpl.figure(2)
                    # join the eval operating points in increasing x order
                    # (sorted is stable: ties keep insertion order)
                    points = sorted(self._eval_points[line], key=lambda p: p[0])
                    mpl.plot([x for x, _ in points],
                             [y for _, y in points], '--',
                             color='black')

        # only for plots
        if self._end_setup_plot:
            # ``or ''`` guards against subclasses that leave the title unset;
            # a whitespace-only title means "no title at all"
            base_title = self._title or ''
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = base_title
                if not self._eval:
                    title += (" (%s)" % self._states[0])
                elif self._split:
                    title += (" (%s)" % self._states[i])
                mpl.title(title if base_title.replace(' ', '') else '')
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                mpl.legend(loc='best')
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        # do not want to close a shared PDF when running evaluate
        # (closef=False), but a PdfPages created in init_process is owned here
        # and must always be closed so the file is finalized
        meta = self._ctx.meta
        owns_pdf = 'PdfPages' not in meta
        if self._pdf_page is not None and \
           (owns_pdf or meta.get('closef', True)):
            self._pdf_page.close()

    # common protected functions

    def _label(self, base, name, idx):
        ''' Legend label: the user legend if given, otherwise ``base`` plus the
        system number (when several systems) and file name. '''
        if self._legends is not None and len(self._legends) > idx:
            return self._legends[idx]
        if self.n_systems > 1:
            return base + (" %d (%s)" % (idx + 1, name))
        return base + (" (%s)" % name)

    def _set_axis(self):
        ''' Apply axis limits when all four are known. '''
        if self._axlim is not None and None not in self._axlim:
            mpl.axis(self._axlim)
class Roc(PlotBase):
    ''' Handles the plotting of ROC'''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._title = self._title or 'ROC'
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or "1 - False Negative Rate"
        # custom defaults; a user-supplied axlim is copied so that the
        # assignment below does not modify the list stored in ctx.meta
        if self._axlim is None:
            self._axlim = [1e-4, 1.0, 0, 1.0]
        else:
            self._axlim = list(self._axlim)

        if self._min_dig is not None:
            self._axlim[0] = math.pow(10, self._min_dig)

    def compute(self, idx, input_scores, input_names):
        ''' Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc`

        The development curve goes to figure 1. With eval data, the eval
        curve goes to figure 2 when splitting (dashed on figure 1 otherwise),
        and the eval operating points at the dev thresholds for each
        ``lines_at`` FAR value are scattered and recorded for end_process.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        far_values = plot.log_values(self._min_dig or -4)

        mpl.figure(1)
        plot.roc_for_far(
            dev_neg, dev_pos,
            far_values=far_values,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label('development', dev_file, idx)
        )
        if not self._eval:
            return

        eval_neg, eval_pos = neg_list[1], pos_list[1]
        eval_file = input_names[1]
        if self._split:
            mpl.figure(2)

        linestyle = '--' if not self._split else self._linestyles[idx]
        plot.roc_for_far(
            eval_neg, eval_pos, linestyle=linestyle,
            far_values=far_values,
            color=self._colors[idx],
            label=self._label('eval', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, thres_line)
                # the ROC y axis is 1 - FNMR
                eval_fnmr = 1 - eval_fnmr
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

class Det(PlotBase):
    ''' Handles the plotting of DET '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._title = self._title or 'DET'
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or 'False Negative Rate'
        if self._far_at is not None:
            # vertical lines are drawn in ppndf-transformed coordinates
            self._trans_far_val = [ppndf(float(k)) for k in self._far_at]
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

        # a user-supplied axlim is copied so that the assignment below does
        # not modify the list stored in ctx.meta
        if self._axlim is None:
            self._axlim = [0.01, 99, 0.01, 99]
        else:
            self._axlim = list(self._axlim)

        if self._min_dig is not None:
            # DET limits appear to be in percent, hence the factor 100
            self._axlim[0] = math.pow(10, self._min_dig) * 100

    def compute(self, idx, input_scores, input_names):
        ''' Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`

        The development curve goes to figure 1. With eval data, the eval
        curve goes to figure 2 when splitting (dashed on figure 1 otherwise),
        and the ppndf-transformed eval operating points at the dev thresholds
        for each ``lines_at`` FAR value are scattered and recorded.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        eval_neg = eval_pos = None
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        mpl.figure(1)
        plot.det(
            dev_neg, dev_pos, self._points, color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label('development', dev_file, idx)
        )
        if not self._eval or eval_neg is None:
            return

        if self._split:
            mpl.figure(2)
        linestyle = '--' if not self._split else self._linestyles[idx]
        plot.det(
            eval_neg, eval_pos, self._points, color=self._colors[idx],
            linestyle=linestyle,
            label=self._label('eval', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, thres_line)
                eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

    def _set_axis(self):
        ''' DET plots use their own axis setup. '''
        plot.det_axis(self._axlim)

class Epc(PlotBase):
    ''' Handles the plotting of EPC '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        # each system must come as a (dev, eval) pair of score files
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev and eval score files")
        self._title = self._title or 'EPC'
        self._x_label = self._x_label or r'$\alpha$'
        self._y_label = self._y_label or 'HTER (%)'
        # an EPC always uses eval data, drawn on one figure, without FAR lines
        self._eval, self._split = True, False
        self._nb_figs = 1
        self._far_at = None

    def compute(self, idx, input_scores, input_names):
        ''' Plot EPC using :py:func:`bob.measure.plot.epc`

        The first loaded file of the system is the dev set and the second is
        the eval set; the curve label combines both base names.
        '''
        negatives, positives, _ = utils.get_fta_list(input_scores)
        dev_name, eval_name = input_names[0], input_names[1]
        curve_label = self._label('curve', dev_name + "_" + eval_name, idx)
        plot.epc(
            negatives[0], positives[0], negatives[1], positives[1],
            self._points,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=curve_label,
        )

class Hist(PlotBase):
    ''' Functional base class for histograms'''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._nbins = meta.get('n_bins', [])
        self._thres = meta.get('thres')
        if self._thres is not None and len(self._thres) != self.n_systems:
            if len(self._thres) == 1:
                # a single threshold is shared by every system
                self._thres = self._thres * self.n_systems
            else:
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._criterion = meta.get('criterion')
        self._nrows = meta.get('n_row', 1)
        self._ncols = meta.get('n_col', 1)
        self._nlegends = meta.get('legends_ncol', 10)
        # number of subplots on one pdf page
        self._step_print = int(self._nrows * self._ncols)
        self._title_base = 'Scores'
        # keep user-provided axis labels, as the other plots do
        self._y_label = self._y_label or 'Probability density'
        self._x_label = self._x_label or 'Scores values'
        self._end_setup_plot = False

    def compute(self, idx, input_scores, input_names):
        ''' Draw histograms of negative and positive scores.

        Systems are laid out on an ``n_row`` x ``n_col`` grid; each full
        page (or the last system) is saved to the pdf and a new figure is
        started.
        '''
        dev_neg, dev_pos, eval_neg, eval_pos, threshold = \
            self._get_neg_pos_thres(idx, input_scores, input_names)
        dev_file = input_names[0]
        eval_file = None if len(input_names) != 2 else input_names[1]

        # position of this system on the current page
        n = idx % self._step_print
        col = n % self._ncols
        sub_plot_idx = n + 1
        axis = mpl.subplot(self._nrows, self._ncols, sub_plot_idx)
        # eval scores are shown when available, dev scores otherwise
        neg = eval_neg if eval_neg is not None else dev_neg
        pos = eval_pos if eval_pos is not None else dev_pos
        self._setup_hist(neg, pos)
        if col == 0:
            axis.set_ylabel(self._y_label)
        # rest to be printed, counted from the first system of this page
        rest_print = self.n_systems - (idx // self._step_print) * self._step_print
        # x label only when no subplot will be drawn below this one
        if n + self._ncols >= min(self._step_print, rest_print):
            axis.set_xlabel(self._x_label)
        axis.set_title(self._get_title(idx, dev_file, eval_file))
        label = "%s threshold%s" % (
            '' if self._criterion is None else
            self._criterion.upper(), ' (dev)' if self._eval else ''
        )
        self._lines(threshold, label, neg, pos, idx)
        if sub_plot_idx == 1:
            self._plot_legends()
        if self._step_print == sub_plot_idx or idx == self.n_systems - 1:
            mpl.tight_layout()
            self._pdf_page.savefig(mpl.gcf(), bbox_inches="tight")
            mpl.clf()
            mpl.figure()

    def _get_title(self, idx, dev_file, eval_file):
        ''' Subplot title: the legend of the system, else the user title,
        else ``'Scores'``; a whitespace-only user title disables titles. '''
        title = self._legends[idx] if self._legends is not None else None
        title = title or self._title or self._title_base
        title = '' if self._title is not None and not self._title.replace(' ', '') else title
        return title or ''

    def _plot_legends(self):
        ''' Draw one figure-level legend gathering the handles of all axes. '''
        lines = []
        labels = []
        for ax in mpl.gcf().get_axes():
            li, la = ax.get_legend_handles_labels()
            lines += li
            labels += la
        mpl.gcf().legend(
            lines, labels, fontsize=6, loc='upper center', fancybox=True,
            framealpha=0.5, ncol=self._nlegends,
        )

    def _get_neg_pos_thres(self, idx, input_scores, input_names):
        ''' Split the loaded scores of one system into dev and eval lists and
        pick its threshold (user given, or from the first dev file). '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        length = len(neg_list)
        # can have several files for one system: even indices are dev,
        # odd indices are eval
        dev_neg = [neg_list[x] for x in range(0, length, 2)]
        dev_pos = [pos_list[x] for x in range(0, length, 2)]
        eval_neg = eval_pos = None
        if self._eval:
            eval_neg = [neg_list[x] for x in range(1, length, 2)]
            eval_pos = [pos_list[x] for x in range(1, length, 2)]

        threshold = utils.get_thres(
            self._criterion, dev_neg[0], dev_pos[0]
        ) if self._thres is None else self._thres[idx]
        return dev_neg, dev_pos, eval_neg, eval_pos, threshold

    def _density_hist(self, scores, n, **kwargs):
        ''' Plot a density histogram of ``scores``; ``n`` selects the bin
        setting in ``self._nbins`` ('auto' when not provided). '''
        bins = 'auto' if len(self._nbins) <= n else self._nbins[n]
        counts, edges, patches = mpl.hist(
            scores, density=True, bins=bins, **kwargs
        )
        return (counts, edges, patches)

    def _lines(self, threshold, label=None, neg=None, pos=None,
               idx=None, **kwargs):
        ''' Plot a vertical threshold line; ``neg``, ``pos`` and ``idx`` are
        unused here but available to derived classes. '''
        label = label or 'Threshold'
        kwargs.setdefault('color', 'C3')
        kwargs.setdefault('linestyle', '--')
        kwargs.setdefault('label', label)
        # plot a vertical threshold line
        mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs)

    def _setup_hist(self, neg, pos):
        ''' This function can be overwritten in derived classes'''
        self._density_hist(
            neg[0], n=0,
            label='Negatives', alpha=0.5, color='C3'
        )
        self._density_hist(
            pos[0], n=1,
            label='Positives', alpha=0.5, color='C0'
        )