# figure.py
'''Runs error analysis on score sets, outputs metrics and plots'''

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
import sys
6
import os.path
7
8
9
10
11
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
12
from .. import (far_threshold, plot, utils, ppndf)
# Dash patterns used to tell systems apart in the plots. Each style is a
# matplotlib ``(offset, (on, off, ...))`` dash tuple; an empty on/off sequence
# draws a solid line. Plotting code picks a style from ``LINESTYLES`` by
# system index, so the order below decides which system gets which pattern.
_NAMED_LINESTYLES = (
    ('solid', (0, ())),
    ('dashed', (0, (4, 4))),
    ('dotted', (0, (1, 5))),
    ('dashdotted', (0, (3, 5, 1, 5))),
    ('dashdotdotted', (0, (3, 5, 1, 5, 1, 5))),
    ('densely dashed', (0, (5, 1))),
    ('densely dotted', (0, (1, 1))),
    ('densely dashdotted', (0, (3, 1, 1, 1))),
    ('densely dashdotdotted', (0, (3, 1, 1, 1, 1, 1))),
    ('loosely dashed', (0, (5, 10))),
    ('loosely dashdotted', (0, (3, 10, 1, 10))),
    ('loosely dashdotdotted', (0, (3, 10, 1, 10, 1, 10))),
    ('loosely dotted', (0, (1, 10))),
)

LINESTYLES = [style for _name, style in _NAMED_LINESTYLES]

class MeasureBase(object):
    """Base class for metrics and plots.

    This abstract class defines the framework used to plot or compute metrics
    from a list of (positive, negative) score tuples.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    n_systems: int
        Number of systems, i.e. ``len(scores) // min_arg``.
    """
    __metaclass__ = ABCMeta  # for python 2.7 compatibility

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx : :py:class:`click.Context`
            Click context; options are read from ``ctx.meta``.
        scores : :any:`list`
            List of input files, grouped per system: ``min_arg`` consecutive
            files belong to the same system (e.g. ``dev-scores1 eval-scores1
            dev-scores2 eval-scores2`` when ``min_arg`` is 2).
        evaluation : :py:class:`bool`
            True if eval data are used
        func_load : callable
            Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If the number of files is not a non-zero multiple of ``min_arg``,
            or if the number of titles differs from the number of systems.
        """
        self._scores = scores
        self._ctx = ctx
        self.func_load = func_load
        self._eval = evaluation
        self._titles = ctx.meta.get('titles')
        # number of files that make up one system (e.g. 2 for dev + eval)
        self._min_arg = ctx.meta.get('min_arg', 1)
        if not scores or len(scores) % self._min_arg != 0:
            raise click.BadParameter(
                'Number of argument must be a non-zero multiple of %d'
                % self._min_arg
            )
        self.n_systems = len(scores) // self._min_arg
        if self._titles is not None and len(self._titles) != self.n_systems:
            raise click.BadParameter("Number of titles must be equal to the "
                                     "number of systems")

    def run(self):
        """ Generate outputs (e.g. metrics, files, pdf plots).

        This function calls
        :func:`~bob.measure.script.figure.MeasureBase.init_process` (before
        the loop), :py:func:`~bob.measure.script.figure.MeasureBase.compute`
        once per system (in the loop) and
        :py:func:`~bob.measure.script.figure.MeasureBase.end_process`
        (after the loop).
        """
        # init matplotlib, log files, ...
        self.init_process()
        # Scores are laid out system by system:
        #   SysA-dev SysA-eval ... SysB-dev SysB-eval ...
        # so system ``idx`` owns the ``min_arg`` files starting at
        # ``idx * min_arg``.
        for idx in range(self.n_systems):
            start = idx * self._min_arg
            input_scores, input_names = self._load_files(
                self._scores[start:start + self._min_arg]
            )
            self.compute(idx, input_scores, input_names)
        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        before iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots from the given scores provided by
        :py:func:`~bob.measure.script.figure.MeasureBase.run`.
        Should be reimplemented in derived classes

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        """
        pass

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        after iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # common protected functions

    def _load_files(self, filepaths):
        ''' Load the input files and return the base names of the files

        Returns
        -------
            scores: :any:`list`:
                A list that contains the output of
                ``func_load`` for the given files
            basenames: :any:`list`:
                A list of basenames for the given files (the file name up to
                its first ``.``)
        '''
        scores = [self.func_load(path) for path in filepaths]
        basenames = [os.path.basename(path).split(".")[0]
                     for path in filepaths]
        return scores, basenames

class Metrics(MeasureBase):
    ''' Compute metrics from score files

    Attributes
    ----------
    log_file: file-like
        output stream (``sys.stdout`` unless a ``log`` path was given)
    '''

    # row labels of the metrics table, in the order produced by _format_rates
    _ROW_LABELS = ('FtA', 'FMR', 'FNMR', 'FAR', 'FRR', 'HTER')

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._tablefmt = meta.get('tablefmt')
        self._criter = meta.get('criter')
        self._open_mode = meta.get('open_mode')
        self._thres = meta.get('thres')
        if self._thres is not None:
            if len(self._thres) == 1:
                # a single threshold is shared by every system
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._far = meta.get('far_value')
        self._log = meta.get('log')
        self.log_file = sys.stdout
        if self._log is not None:
            # open() rejects a None mode; default to writing a fresh file
            self.log_file = open(self._log, self._open_mode or 'w')

    @staticmethod
    def _format_rates(fmr, fnmr, fta, n_neg, n_pos):
        ''' Format the FtA, FMR, FNMR, FAR, FRR and HTER table cells.

        FAR = FMR * (1 - FtA), FRR = FtA + FNMR * (1 - FtA) and HTER is their
        mean. FMR/FNMR are also shown as counts out of ``n_neg``/``n_pos``.
        '''
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)
        hter = (far + frr) / 2.0
        n_fm = int(round(fmr * n_neg))  # number of false accepts
        n_fnm = int(round(fnmr * n_pos))  # number of false rejects
        return [
            "%.1f%%" % (100 * fta),
            "%.1f%% (%d/%d)" % (100 * fmr, n_fm, n_neg),
            "%.1f%% (%d/%d)" % (100 * fnmr, n_fnm, n_pos),
            "%.1f%%" % (100 * far),
            "%.1f%%" % (100 * frr),
            "%.1f%%" % (100 * hter),
        ]

    def compute(self, idx, input_scores, input_names):
        ''' Compute the threshold and the metrics table (FtA, FMR, FNMR, FAR,
        FRR, HTER) of one system and write them to ``log_file``.

        The threshold is the user-provided one (``thres``) or is computed on
        the development scores with ``utils.get_thres`` and the ``criter``
        criterion; the same threshold is applied to the evaluation scores
        when they are available.
        '''
        from .. import farfrr
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]

        if self._thres is None:
            threshold = utils.get_thres(
                self._criter, dev_neg, dev_pos, self._far
            )
        else:
            threshold = self._thres[idx]
        title = self._titles[idx] if self._titles is not None else None

        if self._thres is None:
            far_str = ''
            if self._criter == 'far' and self._far is not None:
                far_str = str(self._far)
            click.echo("[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                       % (self._criter.upper(), far_str, title or dev_file,
                          threshold),
                       file=self.log_file)
        else:
            click.echo("[Min. criterion: user provider] Threshold on "
                       "Development set `%s`: %e"
                       % (title or dev_file, threshold), file=self.log_file)

        dev_fmr, dev_fnmr = farfrr(dev_neg, dev_pos, threshold)
        dev_cells = self._format_rates(
            dev_fmr, dev_fnmr, dev_fta, dev_neg.shape[0], dev_pos.shape[0]
        )
        headers = [title or '', 'Development %s' % dev_file]
        rows = [[label, cell]
                for label, cell in zip(self._ROW_LABELS, dev_cells)]

        if self._eval:
            eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]
            # computes statistics for the eval set based on the threshold
            # chosen a priori on the development set
            eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, threshold)
            eval_cells = self._format_rates(
                eval_fmr, eval_fnmr, eval_fta,
                eval_neg.shape[0], eval_pos.shape[0]
            )
            headers.append('Eval. %s' % input_names[1])
            for row, cell in zip(rows, eval_cells):
                row.append(cell)

        click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)

    def end_process(self):
        ''' Close log file if needed'''
        if self._log is not None:
            self.log_file.close()

class PlotBase(MeasureBase):
    ''' Base class for plots. Regroups several options and code
    shared by the different plots
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._output = meta.get('output')
        self._points = meta.get('points', 100)
        self._split = meta.get('split')
        self._axlim = meta.get('axlim')
        self._clayout = meta.get('clayout')
        self._far_at = meta.get('lines_at')
        # x positions of the FAR lines; subclasses may transform them
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            self._eval_points = {line: [] for line in self._far_at}
            self._lines_val = []
        self._print_fn = meta.get('show_fn', True)
        self._x_rotation = meta.get('x_rotation')
        if 'style' in meta:
            mpl.style.use(meta['style'])
        # a second figure is used only when dev and eval are split
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._states = ['Development', 'Evaluation']
        self._title = meta.get('title')
        self._x_label = meta.get('x_label')
        self._y_label = meta.get('y_label')
        self._grid_color = 'silver'
        self._pdf_page = None
        # True when init_process created the PdfPages itself and therefore
        # has to close it
        self._own_pdf = False
        self._end_setup_plot = True

    def init_process(self):
        ''' Select the pdf backend if needed, open the pdf output (or reuse
        the one shared in ``ctx.meta``) and prepare the figures '''
        if not hasattr(matplotlib, 'backends'):
            matplotlib.use('pdf')

        if 'PdfPages' in self._ctx.meta:
            self._pdf_page = self._ctx.meta['PdfPages']
            self._own_pdf = False
        else:
            self._pdf_page = PdfPages(self._output)
            self._own_pdf = True

        fs = self._ctx.meta.get('figsize')
        for i in range(self._nb_figs):
            fig = mpl.figure(i + 1, figsize=fs)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        ''' Draw the FAR lines, set title, legend, axis labels and grid,
        save the figures and close the pdf if needed '''
        # draw vertical lines
        if self._far_at is not None:
            for line, line_trans in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans], [-100.0, 100.], "--",
                    color='black'
                )
                if self._eval and self._split:
                    mpl.figure(2)
                    # connect the eval operating points in increasing x order
                    points = sorted(
                        self._eval_points[line], key=lambda point: point[0]
                    )
                    mpl.plot([x for x, _ in points],
                             [y for _, y in points], '--',
                             color='black')

        # only for plots
        if self._end_setup_plot:
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = self._title or ''
                if not self._eval:
                    title += " (%s)" % self._states[0]
                elif self._split:
                    title += " (%s)" % self._states[i]
                mpl.title(title)
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                mpl.legend(loc='best')
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        # A PdfPages shared through ctx.meta may still be written to (e.g.
        # when running `evaluate`), so it is closed only if 'closef' allows
        # it. A PdfPages created by init_process is always closed so that
        # the file is finalized.
        if self._own_pdf or (
                'PdfPages' in self._ctx.meta
                and self._ctx.meta.get('closef', True)):
            self._pdf_page.close()

    # common protected functions

    def _label(self, base, name, idx):
        ''' Legend label of a curve: the user title of the system if given,
        otherwise ``base`` followed by the system number (when several
        systems are plotted) and the file name. '''
        if self._titles is not None and len(self._titles) > idx:
            return self._titles[idx]
        if self.n_systems > 1:
            return base + (" %d (%s)" % (idx + 1, name))
        return base + (" (%s)" % name)

    def _set_axis(self):
        ''' Apply the axis limits when all four are provided '''
        if self._axlim is not None and None not in self._axlim:
            mpl.axis(self._axlim)

class Roc(PlotBase):
    ''' Handles the plotting of ROC'''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._title = self._title or 'ROC'
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or "1 - False Negative Rate"
        # custom defaults
        if self._axlim is None:
            self._axlim = [1e-4, 1.0, 1e-4, 1.0]

    def compute(self, idx, input_scores, input_names):
        ''' Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc_for_far`.

        When ``lines_at`` is given, the eval operating points obtained with
        the dev thresholds at those FAR values are drawn as dots.'''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        # cycle through the dash patterns so any number of systems works
        system_style = LINESTYLES[idx % len(LINESTYLES)]

        mpl.figure(1)
        if not self._eval:
            plot.roc_for_far(
                dev_neg, dev_pos,
                color=self._colors[idx], linestyle=system_style,
                label=self._label('development', dev_file, idx)
            )
            return

        eval_neg, eval_pos = neg_list[1], pos_list[1]
        eval_file = input_names[1]
        plot.roc_for_far(
            dev_neg, dev_pos,
            color=self._colors[idx],
            linestyle=system_style if self._split else '-',
            label=self._label('development', dev_file, idx)
        )
        if self._split:
            mpl.figure(2)
        plot.roc_for_far(
            eval_neg, eval_pos,
            color=self._colors[idx],
            linestyle=system_style if self._split else '--',
            label=self._label('eval', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, thres_line)
                # the ROC y axis is 1 - FNMR
                eval_tpr = 1 - eval_fnmr
                mpl.scatter(eval_fmr, eval_tpr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_tpr))

class Det(PlotBase):
    ''' Handles the plotting of DET '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._title = self._title or 'DET'
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or 'False Negative Rate'
        if self._far_at is not None:
            # the FAR lines go through ppndf, like the eval points below
            self._trans_far_val = [ppndf(float(k)) for k in self._far_at]
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

    def compute(self, idx, input_scores, input_names):
        ''' Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`.

        When ``lines_at`` is given, the eval operating points obtained with
        the dev thresholds at those FAR values are drawn as dots.'''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        eval_neg = eval_pos = eval_file = None
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]
        # cycle through the dash patterns so any number of systems works
        system_style = LINESTYLES[idx % len(LINESTYLES)]

        mpl.figure(1)
        if self._eval and eval_neg is not None:
            plot.det(
                dev_neg, dev_pos, self._points, color=self._colors[idx],
                linestyle=system_style if self._split else '-',
                label=self._label('development', dev_file, idx)
            )
            if self._split:
                mpl.figure(2)
            plot.det(
                eval_neg, eval_pos, self._points, color=self._colors[idx],
                linestyle=system_style if self._split else '--',
                label=self._label('eval', eval_file, idx)
            )
            if self._far_at is not None:
                from .. import farfrr
                for line in self._far_at:
                    thres_line = far_threshold(dev_neg, dev_pos, line)
                    eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, thres_line)
                    # map the point onto the ppndf-scaled DET axes
                    eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
                    mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                    self._eval_points[line].append((eval_fmr, eval_fnmr))
        else:
            plot.det(
                dev_neg, dev_pos, self._points, color=self._colors[idx],
                linestyle=system_style,
                label=self._label('development', dev_file, idx)
            )

    def _set_axis(self):
        ''' Apply the user axis limits, or the default DET limits '''
        if self._axlim is not None and None not in self._axlim:
            plot.det_axis(self._axlim)
        else:
            plot.det_axis([0.01, 99, 0.01, 99])

class Epc(PlotBase):
    ''' Handles the plotting of EPC '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev and eval score files")
        self._title = self._title or 'EPC'
        self._x_label = self._x_label or r'$\alpha$'
        self._y_label = self._y_label or 'HTER (%)'
        self._eval = True  # always eval data with EPC
        self._split = False
        self._nb_figs = 1
        self._far_at = None

    def compute(self, idx, input_scores, input_names):
        ''' Plot EPC using :py:func:`bob.measure.plot.epc` '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        # EPC always has eval data (enforced in __init__)
        eval_neg, eval_pos = neg_list[1], pos_list[1]
        eval_file = input_names[1]

        plot.epc(
            dev_neg, dev_pos, eval_neg, eval_pos, self._points,
            color=self._colors[idx],
            # cycle through the dash patterns so any number of systems works
            linestyle=LINESTYLES[idx % len(LINESTYLES)],
            label=self._label(
                'curve', dev_file + "_" + eval_file, idx
            )
        )

class Hist(PlotBase):
    ''' Functional base class for histograms'''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._nbins = meta.get('n_bins')
        self._thres = meta.get('thres')
        # without eval scores the dev histogram is always shown
        self._show_dev = meta.get('show_dev', not self._eval) or not self._eval
        if self._thres is not None and len(self._thres) != self.n_systems:
            if len(self._thres) == 1:
                # a single threshold is shared by every system
                self._thres = self._thres * self.n_systems
            else:
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._criter = meta.get('criter')
        # a user-supplied y label takes precedence over the defaults
        self._y_label = self._y_label or (
            'Dev. probability density' if self._eval else 'density'
        )
        self._x_label = 'Scores' if not self._eval else ''
        self._title_base = self._title or 'Scores'
        # histograms are saved in compute(), not by PlotBase.end_process
        self._end_setup_plot = False

    def compute(self, idx, input_scores, input_names):
        ''' Draw histograms of negative and positive scores of one system on
        a new figure (dev on top and eval below when both are shown) and
        save it to the pdf.'''
        dev_neg, dev_pos, eval_neg, eval_pos, threshold = \
            self._get_neg_pos_thres(idx, input_scores, input_names)
        dev_file = input_names[0]
        eval_file = None if len(input_names) != 2 else input_names[1]

        fig = mpl.figure()
        if eval_neg is not None and self._show_dev:
            mpl.subplot(2, 1, 1)

        if self._show_dev:
            self._setup_hist(dev_neg, dev_pos)
            mpl.title(self._get_title(idx, dev_file, eval_file))
            mpl.ylabel(self._y_label)
            mpl.xlabel(self._x_label)
            if eval_neg is not None:
                # the eval subplot below carries the x tick labels
                ax = mpl.gca()
                ax.axes.get_xaxis().set_ticklabels([])
            # Setup lines, corresponding axis and legends
            self._lines(threshold, dev_neg, dev_pos)
            if self._eval:
                self._plot_legends()

        if eval_neg is not None:
            if self._show_dev:
                mpl.subplot(2, 1, 2)
            self._setup_hist(eval_neg, eval_pos)
            if not self._show_dev:
                mpl.title(self._get_title(idx, dev_file, eval_file))
            mpl.ylabel('Eval. probability density')
            mpl.xlabel(self._x_label)
            # Setup lines, corresponding axis and legends
            self._lines(threshold, eval_neg, eval_pos)
            if not self._show_dev:
                self._plot_legends()

        self._pdf_page.savefig(fig)
        # a new figure is created for every system; release it once saved
        # so figures do not pile up in pyplot's figure manager
        mpl.close(fig)

    def _get_title(self, idx, dev_file, eval_file):
        ''' Title of a system's histogram: its user title if given, otherwise
        the base title, followed by the file name(s) when ``show_fn`` is
        set. '''
        title = self._titles[idx] if self._titles is not None else None
        if title is None:
            if self._print_fn:
                files = str(dev_file)
                if self._eval:
                    files += " / %s" % str(eval_file)
                title = '%s \n (%s)' % (self._title_base, files)
            else:
                title = self._title_base
        return title

    def _plot_legends(self):
        ''' Merge the legend entries of every axis of the current figure
        into a single legend '''
        lines = []
        labels = []
        for ax in mpl.gcf().get_axes():
            li, la = ax.get_legend_handles_labels()
            lines += li
            labels += la
        if self._show_dev and self._eval:
            mpl.legend(
                lines, labels, loc='upper center', ncol=6,
                bbox_to_anchor=(0.5, -0.01), fontsize=6
            )
        else:
            mpl.legend(lines, labels,
                       loc='best', fancybox=True, framealpha=0.5)

    def _get_neg_pos_thres(self, idx, input_scores, input_names):
        ''' Split the loaded scores of one system into dev/eval negatives and
        positives (even indices are dev, odd indices eval) and pick the
        threshold: user-provided, or computed on the first dev set. '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        length = len(neg_list)
        # can have several files for one system
        dev_neg = [neg_list[x] for x in range(0, length, 2)]
        dev_pos = [pos_list[x] for x in range(0, length, 2)]
        eval_neg = eval_pos = None
        if self._eval:
            eval_neg = [neg_list[x] for x in range(1, length, 2)]
            eval_pos = [pos_list[x] for x in range(1, length, 2)]

        threshold = utils.get_thres(
            self._criter, dev_neg[0], dev_pos[0]
        ) if self._thres is None else self._thres[idx]
        return dev_neg, dev_pos, eval_neg, eval_pos, threshold

    def _density_hist(self, scores, **kwargs):
        ''' Draw a normalized (density) histogram with ``n_bins`` bins '''
        n, bins, patches = mpl.hist(
            scores, density=True, bins=self._nbins, **kwargs
        )
        return (n, bins, patches)

    def _lines(self, threshold, neg=None, pos=None, **kwargs):
        ''' Draw the vertical threshold line '''
        label = 'Threshold' if self._criter is None else self._criter.upper()
        kwargs.setdefault('color', 'C3')
        kwargs.setdefault('linestyle', '--')
        kwargs.setdefault('label', label)
        # plot a vertical threshold line
        mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs)

    def _setup_hist(self, neg, pos):
        ''' This function can be overwritten in derived classes'''
        self._density_hist(
            pos[0], label='Positives', alpha=0.5, color='C0'
        )
        self._density_hist(
            neg[0], label='Negatives', alpha=0.5, color='C3'
        )