figure.py 26.7 KB
Newer Older
1
2
3
4
5
'''Runs error analysis on score sets, outputs metrics and plots'''

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
import sys
6
import os.path
7
8
9
10
11
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
12
from .. import (far_threshold, plot, utils, ppndf)
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40

# Dash patterns ``(offset, (on, off, ...))`` used (cyclically, by system
# index) to tell the curves of different systems apart.
LINESTYLES = [
    (0, dashes) for dashes in (
        (),                        # solid
        (4, 4),                    # dashed
        (1, 5),                    # dotted
        (3, 5, 1, 5),              # dashdotted
        (3, 5, 1, 5, 1, 5),        # dashdotdotted
        (5, 1),                    # densely dashed
        (1, 1),                    # densely dotted
        (3, 1, 1, 1),              # densely dashdotted
        (3, 1, 1, 1, 1, 1),        # densely dashdotdotted
        (5, 10),                   # loosely dashed
        (3, 10, 1, 10),            # loosely dashdotted
        (3, 10, 1, 10, 1, 10),     # loosely dashdotdotted
        (1, 10),                   # loosely dotted
    )
]

class MeasureBase(object):
    """Base class for metrics and plots.

    This abstract class defines the framework to plot or compute metrics from
    a list of (positive, negative) score tuples.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    n_systems: int
        Number of systems, i.e. ``len(scores) // min_arg``
    """
    __metaclass__ = ABCMeta  # for python 2.7 compatibility

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx : click context
            Click context; options are read from its ``meta`` dictionary
            (``min_arg``, ``titles``, ...).
        scores : :any:`list`:
            List of input files grouped per system, ``min_arg`` consecutive
            files for each system (e.g. dev-{1, 2, 3}, or
            {dev,eval}-scores1 {dev,eval}-scores2)
        evaluation : :py:class:`bool`
            True if eval data are used
        func_load : Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If the number of score files is not a non-zero multiple of
            ``min_arg``, or the number of titles differs from the number of
            systems.
        """
        self._ctx = ctx
        self._scores = scores
        self._eval = evaluation
        self.func_load = func_load
        self._min_arg = ctx.meta.get('min_arg', 1)
        self._titles = ctx.meta.get('titles')
        if len(scores) < 1 or len(scores) % self._min_arg != 0:
            raise click.BadParameter(
                'Number of argument must be a non-zero multiple of %d' % self._min_arg
            )
        self.n_systems = len(scores) // self._min_arg
        if self._titles is not None and len(self._titles) != self.n_systems:
            raise click.BadParameter("Number of titles must be equal to the "
                                     "number of systems")

    def run(self):
        """ Generate outputs (e.g. metrics, files, pdf plots).
        This function calls abstract methods
        :func:`~bob.measure.script.figure.MeasureBase.init_process` (before
        loop), :py:func:`~bob.measure.script.figure.MeasureBase.compute`
        (in the loop iterating through the different
        systems) and :py:func:`~bob.measure.script.figure.MeasureBase.end_process`
        (after the loop).
        """
        # init matplotlib, log files, ...
        self.init_process()
        # iterates through the different systems and feed `compute`
        # with the dev (and eval) scores of each system.
        # Files are grouped per system (e.g. sysA-dev sysA-eval sysB-dev
        # sysB-eval ...), so system ``idx`` owns the ``min_arg`` consecutive
        # entries starting at ``idx * min_arg``.
        for idx in range(self.n_systems):
            start = idx * self._min_arg
            input_scores, input_names = self._load_files(
                self._scores[start:start + self._min_arg]
            )
            self.compute(idx, input_scores, input_names)
        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        before iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots from the given scores provided by
        :py:func:`~bob.measure.script.figure.MeasureBase.run`.
        Should be reimplemented in derived classes

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        """
        pass

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        after iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # common protected functions

    def _load_files(self, filepaths):
        ''' Load the input files and return the base names of the files

        Returns
        -------
            scores: :any:`list`:
                A list that contains the output of
                ``func_load`` for the given files
            basenames: :any:`list`:
                A list of basenames (text before the first ``.``) for the
                given files
        '''
        scores = []
        basenames = []
        for filename in filepaths:
            basenames.append(os.path.basename(filename).split(".")[0])
            scores.append(self.func_load(filename))
        return scores, basenames
146
147
148
149
150
151
152
153
154

class Metrics(MeasureBase):
    ''' Compute metrics from score files

    Attributes
    ----------
    log_file: file-like
        output stream (``sys.stdout`` unless ``log`` is given in
        ``ctx.meta``, in which case that file is opened)
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._tablefmt = meta.get('tablefmt')
        self._criterion = meta.get('criterion')
        self._open_mode = meta.get('open_mode')
        self._thres = meta.get('thres')
        if self._thres is not None:
            if len(self._thres) == 1:
                # a single threshold is shared by all systems
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._far = meta.get('far_value')
        self._log = meta.get('log')
        self.log_file = sys.stdout
        if self._log is not None:
            # ``open`` rejects a ``None`` mode: default to (over)writing
            self.log_file = open(self._log, self._open_mode or 'w')

    def compute(self, idx, input_scores, input_names):
        ''' Compute metrics thresholds and tables (FAR, FMR, FNMR, HTER) for
        given system inputs

        The threshold is either the user-provided one for system ``idx`` or
        the one returned by ``utils.get_thres`` on the development scores; it
        is then applied a priori to the eval scores (if any).
        '''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]

        if self._thres is None:
            threshold = utils.get_thres(
                self._criterion, dev_neg, dev_pos, self._far
            )
        else:
            threshold = self._thres[idx]
        title = self._titles[idx] if self._titles is not None else None

        if self._thres is None:
            far_str = ''
            if self._criterion == 'far' and self._far is not None:
                far_str = str(self._far)
            click.echo("[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                       % (self._criterion.upper(), far_str,
                          title or dev_file, threshold),
                       file=self.log_file)
        else:
            click.echo("[Min. criterion: user provided] Threshold on "
                       "Development set `%s`: %e"
                       % (title or dev_file, threshold), file=self.log_file)

        row_names = ('FtA', 'FMR', 'FNMR', 'FAR', 'FRR', 'HTER')
        headers = [title or '', 'Development %s' % dev_file]
        rows = [[name, value] for name, value in
                zip(row_names, self._rates(dev_neg, dev_pos, dev_fta, threshold))]

        if self._eval:
            # computes statistics for the eval set based on the threshold a priori
            eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]
            headers.append('Eval. %s' % input_names[1])
            eval_rates = self._rates(eval_neg, eval_pos, eval_fta, threshold)
            for row, value in zip(rows, eval_rates):
                row.append(value)

        click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)

    @staticmethod
    def _rates(neg, pos, fta, threshold):
        ''' Return the formatted FtA, FMR, FNMR, FAR, FRR and HTER strings for
        one score set at ``threshold``.

        FAR and FRR fold the failure-to-acquire rate ``fta`` into the
        FMR/FNMR returned by ``farfrr``.
        '''
        from .. import farfrr
        fmr, fnmr = farfrr(neg, pos, threshold)
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)
        hter = (far + frr) / 2.0

        n_neg = neg.shape[0]  # number of impostors
        n_fm = int(round(fmr * n_neg))  # number of false accepts
        n_pos = pos.shape[0]  # number of clients
        n_fnm = int(round(fnmr * n_pos))  # number of false rejects

        return [
            "%.1f%%" % (100 * fta),
            "%.1f%% (%d/%d)" % (100 * fmr, n_fm, n_neg),
            "%.1f%% (%d/%d)" % (100 * fnmr, n_fnm, n_pos),
            "%.1f%%" % (100 * far),
            "%.1f%%" % (100 * frr),
            "%.1f%%" % (100 * hter),
        ]

    def end_process(self):
        ''' Close log file if needed'''
        if self._log is not None:
            self.log_file.close()

class PlotBase(MeasureBase):
    ''' Base class for plots. Regroup several options and code
    shared by the different plots
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._output = meta.get('output')
        self._points = meta.get('points', 100)
        self._split = meta.get('split')
        self._axlim = meta.get('axlim')
        self._clayout = meta.get('clayout')
        self._far_at = meta.get('lines_at')
        # subclasses may replace this with transformed values (e.g. DET)
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            self._eval_points = {line: [] for line in self._far_at}
            self._lines_val = []
        self._print_fn = meta.get('show_fn', True)
        self._x_rotation = meta.get('x_rotation')
        if 'style' in meta:
            mpl.style.use(meta['style'])
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._states = ['Development', 'Evaluation']
        self._title = meta.get('title')
        self._x_label = meta.get('x_label')
        self._y_label = meta.get('y_label')
        self._grid_color = 'silver'
        self._pdf_page = None
        # True only when init_process created the PdfPages itself, in which
        # case nobody else can close it
        self._own_pdf_page = False
        self._end_setup_plot = True

    def init_process(self):
        ''' Select the pdf backend if needed, get (or create) the PdfPages
        output and prepare the figures (size, constrained layout). '''
        if not hasattr(matplotlib, 'backends'):
            matplotlib.use('pdf')

        if 'PdfPages' in self._ctx.meta:
            self._pdf_page = self._ctx.meta['PdfPages']
            self._own_pdf_page = False
        else:
            self._pdf_page = PdfPages(self._output)
            self._own_pdf_page = True

        fs = self._ctx.meta.get('figsize')
        for i in range(self._nb_figs):
            fig = mpl.figure(i + 1, figsize=fs)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        ''' Draw the ``lines_at`` markers, set title, legend, axis labels and
        grid colors, save figures and close the pdf if needed '''
        # draw vertical lines
        if self._far_at is not None:
            for line, line_trans in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans], [-100.0, 100.], "--",
                    color='black'
                )
                if self._eval and self._split:
                    mpl.figure(2)
                    # join the eval operating points of all systems,
                    # ordered by their x coordinate
                    points = sorted(self._eval_points[line], key=lambda p: p[0])
                    mpl.plot([x for x, _ in points],
                             [y for _, y in points], '--',
                             color='black')

        # only for plots
        if self._end_setup_plot:
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = self._title
                if not self._eval:
                    title += (" (%s)" % self._states[0])
                elif self._split:
                    title += (" (%s)" % self._states[i])
                mpl.title(title)
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                mpl.legend(loc='best')
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        if self._own_pdf_page:
            # created in init_process: we are the only owner, always close
            self._pdf_page.close()
        elif 'PdfPages' in self._ctx.meta and self._ctx.meta.get('closef', True):
            # a shared PdfPages (e.g. from evaluate) is only closed when
            # ``closef`` allows it
            self._pdf_page.close()

    # common protected functions

    def _label(self, base, name, idx):
        ''' Legend label: the user title of system ``idx`` if any, otherwise
        ``base`` followed by the system number (when several) and ``name``.'''
        if self._titles is not None and len(self._titles) > idx:
            return self._titles[idx]
        if self.n_systems > 1:
            return base + (" %d (%s)" % (idx + 1, name))
        return base + (" (%s)" % name)

    def _set_axis(self):
        ''' Apply ``axlim`` to the current axes when fully specified.'''
        if self._axlim is not None and None not in self._axlim:
            mpl.axis(self._axlim)
376

377
class Roc(PlotBase):
    ''' Handles the plotting of ROC'''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._title = self._title or 'ROC'
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or "1 - False Negative Rate"
        # custom defaults
        if self._axlim is None:
            self._axlim = [1e-4, 1.0, 1e-4, 1.0]

    def compute(self, idx, input_scores, input_names):
        ''' Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc_for_far`

        In split mode the dev and eval curves go to figures 1 and 2 with a
        per-system dash pattern; otherwise both go to figure 1 (dev solid,
        eval dashed). When ``lines_at`` is set, the eval operating points at
        the thresholds computed on dev by ``far_threshold`` are marked and
        recorded in ``self._eval_points``.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        color = self._colors[idx]
        # wrap around: there may be more systems than dash patterns
        system_style = LINESTYLES[idx % len(LINESTYLES)]

        mpl.figure(1)
        if not self._eval:
            plot.roc_for_far(
                dev_neg, dev_pos,
                color=color, linestyle=system_style,
                label=self._label('development', dev_file, idx)
            )
            return

        eval_neg, eval_pos = neg_list[1], pos_list[1]
        eval_file = input_names[1]
        plot.roc_for_far(
            dev_neg, dev_pos,
            color=color, linestyle=system_style if self._split else '-',
            label=self._label('development', dev_file, idx)
        )
        if self._split:
            mpl.figure(2)
        plot.roc_for_far(
            eval_neg, eval_pos,
            color=color, linestyle=system_style if self._split else '--',
            label=self._label('eval', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, thres_line)
                # the ROC y axis is 1 - FNMR
                eval_fnmr = 1 - eval_fnmr
                mpl.scatter(eval_fmr, eval_fnmr, c=color, s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

class Det(PlotBase):
    ''' Handles the plotting of DET '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._title = self._title or 'DET'
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or 'False Negative Rate'
        if self._far_at is not None:
            # vertical lines are drawn in the ppndf-transformed DET space
            self._trans_far_val = [ppndf(float(k)) for k in self._far_at]
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

    def compute(self, idx, input_scores, input_names):
        ''' Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`

        In split mode the dev and eval curves go to figures 1 and 2 with a
        per-system dash pattern; otherwise both go to figure 1 (dev solid,
        eval dashed). When ``lines_at`` is set, the eval operating points at
        the thresholds computed on dev by ``far_threshold`` are marked
        (ppndf-transformed) and recorded in ``self._eval_points``.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        eval_neg = eval_pos = eval_file = None
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]
        color = self._colors[idx]
        # wrap around: there may be more systems than dash patterns
        system_style = LINESTYLES[idx % len(LINESTYLES)]

        mpl.figure(1)
        if eval_neg is None:
            plot.det(
                dev_neg, dev_pos, self._points, color=color,
                linestyle=system_style,
                label=self._label('development', dev_file, idx)
            )
            return

        plot.det(
            dev_neg, dev_pos, self._points, color=color,
            linestyle=system_style if self._split else '-',
            label=self._label('development', dev_file, idx)
        )
        if self._split:
            mpl.figure(2)
        plot.det(
            eval_neg, eval_pos, self._points, color=color,
            linestyle=system_style if self._split else '--',
            label=self._label('eval', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, thres_line)
                eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
                mpl.scatter(eval_fmr, eval_fnmr, c=color, s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

    def _set_axis(self):
        ''' Apply ``axlim`` through ``plot.det_axis`` when fully specified,
        otherwise use ``[0.01, 99, 0.01, 99]``.'''
        if self._axlim is not None and None not in self._axlim:
            plot.det_axis(self._axlim)
        else:
            plot.det_axis([0.01, 99, 0.01, 99])
490
491
492

class Epc(PlotBase):
    ''' Handles the plotting of EPC '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev and eval score files")
        self._title = self._title or 'EPC'
        self._x_label = self._x_label or r'$\alpha$'
        self._y_label = self._y_label or 'HTER (%)'
        self._eval = True  # always eval data with EPC
        self._split = False
        self._nb_figs = 1
        self._far_at = None

    def compute(self, idx, input_scores, input_names):
        ''' Plot EPC using :py:func:`bob.measure.plot.epc`

        Every system provides a dev and an eval score file (enforced in
        ``__init__``), so both are always read here.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        eval_neg, eval_pos = neg_list[1], pos_list[1]
        dev_file, eval_file = input_names[0], input_names[1]

        plot.epc(
            dev_neg, dev_pos, eval_neg, eval_pos, self._points,
            color=self._colors[idx],
            # wrap around: there may be more systems than dash patterns
            linestyle=LINESTYLES[idx % len(LINESTYLES)],
            label=self._label(
                'curve', dev_file + "_" + eval_file, idx
            )
        )

class Hist(PlotBase):
    ''' Functional base class for histograms'''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._nbins = meta.get('n_bins', [])
        self._thres = meta.get('thres')
        # dev histograms are always shown when there is no eval data
        self._show_dev = meta.get('show_dev', not self._eval) or not self._eval
        if self._thres is not None and len(self._thres) != self.n_systems:
            if len(self._thres) == 1:
                # a single threshold is shared by all systems
                self._thres = self._thres * self.n_systems
            else:
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._criterion = meta.get('criterion')
        # user-provided labels (read by PlotBase from ctx.meta) take
        # precedence over these defaults
        self._y_label = self._y_label or (
            'Dev. probability density' if self._eval else 'density'
        )
        self._x_label = self._x_label or ('Scores' if not self._eval else '')
        self._title_base = self._title or 'Scores'
        # figures are saved in compute: skip PlotBase's final figure setup
        self._end_setup_plot = False

    def compute(self, idx, input_scores, input_names):
        ''' Draw histograms of negative and positive scores.

        One figure per system is created, saved to the pdf and closed. When
        both dev and eval histograms are shown they are stacked in two
        subplots.
        '''
        dev_neg, dev_pos, eval_neg, eval_pos, threshold = \
            self._get_neg_pos_thres(idx, input_scores, input_names)
        dev_file = input_names[0]
        eval_file = None if len(input_names) != 2 else input_names[1]

        fig = mpl.figure()
        stacked = eval_neg is not None and self._show_dev
        if stacked:
            mpl.subplot(2, 1, 1)

        if self._show_dev:
            self._setup_hist(dev_neg, dev_pos)
            mpl.title(self._get_title(idx, dev_file, eval_file))
            mpl.ylabel(self._y_label)
            mpl.xlabel(self._x_label)
            if stacked:
                # the eval subplot below carries the x tick labels
                ax = mpl.gca()
                ax.axes.get_xaxis().set_ticklabels([])
            # Setup lines, corresponding axis and legends
            self._lines(threshold, dev_neg, dev_pos)
            if self._eval:
                self._plot_legends()

        if eval_neg is not None:
            if self._show_dev:
                mpl.subplot(2, 1, 2)
            self._setup_hist(eval_neg, eval_pos)
            if not self._show_dev:
                mpl.title(self._get_title(idx, dev_file, eval_file))
            mpl.ylabel('Eval. probability density')
            mpl.xlabel(self._x_label)
            # Setup lines, corresponding axis and legends
            self._lines(threshold, eval_neg, eval_pos)
            if not self._show_dev:
                self._plot_legends()

        self._pdf_page.savefig(fig)
        # a new figure is created for every system: release it once saved
        mpl.close(fig)

    def _get_title(self, idx, dev_file, eval_file):
        ''' Title of system ``idx``: its user title if any, otherwise the
        base title, optionally followed by the file names.'''
        title = self._titles[idx] if self._titles is not None else None
        if title is None:
            title = self._title_base if not self._print_fn else \
                ('%s \n (%s)' % (
                    self._title_base,
                    str(dev_file) + (" / %s" % str(eval_file) if self._eval else "")
                ))
        return title

    def _plot_legends(self):
        ''' Merge the legend entries of all axes of the current figure into
        one legend.'''
        lines = []
        labels = []
        for ax in mpl.gcf().get_axes():
            li, la = ax.get_legend_handles_labels()
            lines += li
            labels += la
        if self._show_dev and self._eval:
            mpl.legend(
                lines, labels, loc='upper center', ncol=6,
                bbox_to_anchor=(0.5, -0.01), fontsize=6
            )
        else:
            mpl.legend(lines, labels,
                       loc='best', fancybox=True, framealpha=0.5)

    def _get_neg_pos_thres(self, idx, input_scores, input_names):
        ''' Split the loaded scores of a system into dev / eval lists
        (even / odd positions) and compute the threshold on the first dev
        set unless user thresholds were given.'''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        length = len(neg_list)
        # can have several files for one system
        dev_neg = [neg_list[x] for x in range(0, length, 2)]
        dev_pos = [pos_list[x] for x in range(0, length, 2)]
        eval_neg = eval_pos = None
        if self._eval:
            eval_neg = [neg_list[x] for x in range(1, length, 2)]
            eval_pos = [pos_list[x] for x in range(1, length, 2)]

        threshold = utils.get_thres(
            self._criterion, dev_neg[0], dev_pos[0]
        ) if self._thres is None else self._thres[idx]
        return dev_neg, dev_pos, eval_neg, eval_pos, threshold

    def _density_hist(self, scores, n, **kwargs):
        ''' Draw a density histogram; ``n`` selects the user bin setting
        (``n_bins[n]``) and falls back to ``'auto'``.'''
        counts, bins, patches = mpl.hist(
            scores, density=True,
            bins='auto' if len(self._nbins) <= n else self._nbins[n],
            **kwargs
        )
        return (counts, bins, patches)

    def _lines(self, threshold, neg=None, pos=None, **kwargs):
        ''' Draw the vertical threshold line (``neg``/``pos`` are unused
        here and available to derived classes).'''
        label = 'Threshold' if self._criterion is None else self._criterion.upper()
        kwargs.setdefault('color', 'C3')
        kwargs.setdefault('linestyle', '--')
        kwargs.setdefault('label', label)
        # plot a vertical threshold line
        mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs)

    def _setup_hist(self, neg, pos):
        ''' This function can be overwritten in derived classes'''
        self._density_hist(
            neg[0], n=0,
            label='Negatives', alpha=0.5, color='C3'
        )
        self._density_hist(
            pos[0], n=1,
            label='Positives', alpha=0.5, color='C0'
        )
654