figure.py 26.5 KB
Newer Older
1 2 3 4 5
'''Runs error analysis on score sets, outputs metrics and plots'''

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
import sys
6
import os.path
7 8 9 10 11
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
12
from .. import (far_threshold, plot, utils, ppndf)
13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40

# Matplotlib dash patterns, written as ``(offset, on-off-sequence)`` tuples.
# They are used to tell apart the curves of different systems; consumers
# should cycle through them with ``idx % len(LINESTYLES)``.
LINESTYLES = [
    (0, dashes)
    for dashes in (
        (),                         # solid
        (4, 4),                     # dashed
        (1, 5),                     # dotted
        (3, 5, 1, 5),               # dashdotted
        (3, 5, 1, 5, 1, 5),         # dashdotdotted
        (5, 1),                     # densely dashed
        (1, 1),                     # densely dotted
        (3, 1, 1, 1),               # densely dashdotted
        (3, 1, 1, 1, 1, 1),         # densely dashdotdotted
        (5, 10),                    # loosely dashed
        (3, 10, 1, 10),             # loosely dashdotted
        (3, 10, 1, 10, 1, 10),      # loosely dashdotdotted
        (1, 10),                    # loosely dotted
    )
]

class MeasureBase(object):
    """Base class for metrics and plots.

    This abstract class defines the framework to plot or compute metrics from
    a list of (positive, negative) score tuples.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    n_systems: int
        Number of systems, i.e. ``len(scores) // min_arg``.
    """
    __metaclass__ = ABCMeta  # for python 2.7 compatibility

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx : :py:class:`click.Context`
            Click context; options are read from ``ctx.meta``
            (``min_arg``, ``titles``).
        scores : :any:`list`:
            List of input files grouped per system, ``min_arg`` files per
            system (e.g. dev-{1, 2, 3}, or {dev,eval}-scores1
            {dev,eval}-scores2, which expands to dev-scores1 eval-scores1
            dev-scores2 eval-scores2)
        evaluation : :py:class:`bool`
            True if eval data are used
        func_load : Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If the number of files is not a non-zero multiple of ``min_arg``
            or if the number of titles differs from the number of systems.
        """
        self._scores = scores
        self._ctx = ctx
        self.func_load = func_load
        self._titles = ctx.meta.get('titles')
        self._eval = evaluation
        # number of files that make up one system (e.g. 2 for dev + eval)
        self._min_arg = ctx.meta.get('min_arg', 1)
        if len(scores) < 1 or len(scores) % self._min_arg != 0:
            raise click.BadParameter(
                'Number of argument must be a non-zero multiple of %d' % self._min_arg
            )
        self.n_systems = len(scores) // self._min_arg
        if self._titles is not None and len(self._titles) != self.n_systems:
            raise click.BadParameter("Number of titles must be equal to the "
                                     "number of systems")

    def run(self):
        """ Generate outputs (e.g. metrics, files, pdf plots).

        This function calls abstract methods
        :func:`~bob.measure.script.figure.MeasureBase.init_process` (before
        loop), :py:func:`~bob.measure.script.figure.MeasureBase.compute`
        (in the loop iterating through the different
        systems) and :py:func:`~bob.measure.script.figure.MeasureBase.end_process`
        (after the loop).
        """
        # init matplotlib, log files, ...
        self.init_process()
        # Iterate through the systems and feed `compute` with the dev (and
        # eval) scores of each one. Files are grouped per system:
        #   SysA-dev SysA-eval ... SysB-dev SysB-eval ...
        # so system `idx` owns the `min_arg` files starting at idx * min_arg.
        for idx in range(self.n_systems):
            start = idx * self._min_arg
            input_scores, input_names = self._load_files(
                self._scores[start:start + self._min_arg]
            )
            self.compute(idx, input_scores, input_names)
        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        before iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots from the given scores provided by
        :py:func:`~bob.measure.script.figure.MeasureBase.run`.
        Should be reimplemented in derived classes

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        """
        pass

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        after iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # common protected functions

    def _load_files(self, filepaths):
        ''' Load the input files and return the base names of the files

        Returns
        -------
            scores: :any:`list`:
                A list that contains the output of
                ``func_load`` for the given files
            basenames: :any:`list`:
                A list of basenames (up to the first dot) for the given files
        '''
        scores = [self.func_load(filename) for filename in filepaths]
        basenames = [
            os.path.basename(filename).split(".")[0] for filename in filepaths
        ]
        return scores, basenames

class Metrics(MeasureBase):
    ''' Compute metrics from score files

    Attributes
    ----------
    log_file: str
        output stream
    '''
    # Row labels of the metrics table, in the order produced by `_rates`
    _ROW_NAMES = ('FtA', 'FMR', 'FNMR', 'FAR', 'FRR', 'HTER')

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        self._tablefmt = ctx.meta.get('tablefmt')
        self._criter = ctx.meta.get('criter')
        self._open_mode = ctx.meta.get('open_mode')
        self._thres = ctx.meta.get('thres')
        if self._thres is not None:
            # a single threshold is shared by every system
            if len(self._thres) == 1:
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._far = ctx.meta.get('far_value')
        self._log = ctx.meta.get('log')
        self.log_file = sys.stdout
        if self._log is not None:
            # `open` rejects a None mode: fall back to overwriting the log
            self.log_file = open(self._log, self._open_mode or 'w')

    @staticmethod
    def _rates(neg, pos, fta, threshold):
        ''' Return the formatted FtA, FMR, FNMR, FAR, FRR and HTER of one set
        (same order as ``_ROW_NAMES``) at the given threshold.'''
        from .. import farfrr
        fmr, fnmr = farfrr(neg, pos, threshold)
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)
        hter = (far + frr) / 2.0

        ni = neg.shape[0]  # number of impostors
        fm = int(round(fmr * ni))  # number of false accepts
        nc = pos.shape[0]  # number of clients
        fnm = int(round(fnmr * nc))  # number of false rejects

        return [
            "%.1f%%" % (100 * fta),
            "%.1f%% (%d/%d)" % (100 * fmr, fm, ni),
            "%.1f%% (%d/%d)" % (100 * fnmr, fnm, nc),
            "%.1f%%" % (100 * far),
            "%.1f%%" % (100 * frr),
            "%.1f%%" % (100 * hter),
        ]

    def compute(self, idx, input_scores, input_names):
        ''' Compute metrics thresholds and tables (FAR, FMR, FNMR, HTER) for
        given system inputs.

        The threshold is either user provided or computed on the development
        set with the selected criterion; it is then applied a priori to the
        evaluation set (if any).'''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]

        if self._thres is None:
            threshold = utils.get_thres(
                self._criter, dev_neg, dev_pos, self._far
            )
        else:
            threshold = self._thres[idx]
        title = self._titles[idx] if self._titles is not None else None

        if self._thres is None:
            far_str = ''
            if self._criter == 'far' and self._far is not None:
                far_str = str(self._far)
            click.echo("[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                       % (self._criter.upper(), far_str, title or dev_file, threshold),
                       file=self.log_file)
        else:
            click.echo("[Min. criterion: user provided] Threshold on "
                       "Development set `%s`: %e"
                       % (title or dev_file, threshold), file=self.log_file)

        headers = [title or '', 'Development %s' % dev_file]
        columns = [self._rates(dev_neg, dev_pos, dev_fta, threshold)]
        if self._eval:
            # statistics for the eval set use the threshold chosen a priori
            headers.append('Eval. %s' % input_names[1])
            columns.append(
                self._rates(neg_list[1], pos_list[1], fta_list[1], threshold)
            )
        raws = [list(row) for row in zip(self._ROW_NAMES, *columns)]

        click.echo(tabulate(raws, headers, self._tablefmt), file=self.log_file)

    def end_process(self):
        ''' Close log file if needed'''
        if self._log is not None:
            self.log_file.close()

class PlotBase(MeasureBase):
    ''' Base class for plots. Regroup several options and code
    shared by the different plots
    '''
    def __init__(self, ctx, scores, evaluation, func_load):
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        self._output = ctx.meta.get('output')
        self._points = ctx.meta.get('points', 100)
        self._split = ctx.meta.get('split')
        self._axlim = ctx.meta.get('axlim')
        self._clayout = ctx.meta.get('clayout')
        self._far_at = ctx.meta.get('lines_at')
        # x positions of the vertical FAR lines (subclasses may transform it)
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            self._eval_points = {line: [] for line in self._far_at}
            self._lines_val = []
        self._print_fn = ctx.meta.get('show_fn', True)
        self._x_rotation = ctx.meta.get('x_rotation')
        if 'style' in ctx.meta:
            mpl.style.use(ctx.meta['style'])
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._states = ['Development', 'Evaluation']
        self._title = ctx.meta.get('title')
        self._x_label = ctx.meta.get('x_label')
        self._y_label = ctx.meta.get('y_label')
        self._grid_color = 'silver'
        self._pdf_page = None
        # True when `init_process` created the PdfPages and must close it
        self._own_pdf_page = False
        self._end_setup_plot = True
        self._kwargs = {}

    def init_process(self):
        ''' Open pdf and set axis font size if provided '''
        if not hasattr(matplotlib, 'backends'):
            matplotlib.use('pdf')

        if 'PdfPages' in self._ctx.meta:
            # shared PdfPages (e.g. provided by `evaluate`)
            self._pdf_page = self._ctx.meta['PdfPages']
            self._own_pdf_page = False
        else:
            self._pdf_page = PdfPages(self._output)
            self._own_pdf_page = True

        figsize = self._ctx.meta.get('figsize')
        for i in range(self._nb_figs):
            fig = mpl.figure(i + 1, figsize=figsize)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        ''' Set title, legend, axis labels, grid colors, save figures, draw
        lines and close pdf if needed '''
        # draw vertical lines at the requested FAR values
        if self._far_at is not None:
            for line, line_trans in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans], [-100.0, 100.], "--",
                    color='black'
                )
                if self._eval and self._split:
                    mpl.figure(2)
                    # join the eval operating points in increasing x order
                    points = sorted(
                        self._eval_points[line], key=lambda point: point[0]
                    )
                    x_values = [x for x, _ in points]
                    y_values = [y for _, y in points]
                    mpl.plot(x_values, y_values, '--', color='black')

        # only for plots
        if self._end_setup_plot:
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                # a missing title must not break the suffix concatenation
                title = self._title or ''
                if not self._eval:
                    title += (" (%s)" % self._states[0])
                elif self._split:
                    title += (" (%s)" % self._states[i])
                mpl.title(title)
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                mpl.legend(loc='best')
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        # Always close a PdfPages we created ourselves; a shared one is closed
        # unless the caller asks otherwise with ``closef`` (e.g. `evaluate`)
        shared_close = 'PdfPages' in self._ctx.meta and \
            self._ctx.meta.get('closef', True)
        if self._pdf_page is not None and (self._own_pdf_page or shared_close):
            self._pdf_page.close()

    # common protected functions

    def _label(self, base, name, idx):
        ''' Legend label of system `idx`: its title if given, otherwise
        `base` decorated with the system number (when several) and `name`.'''
        if self._titles is not None and len(self._titles) > idx:
            return self._titles[idx]
        if self.n_systems > 1:
            return base + (" %d (%s)" % (idx + 1, name))
        return base + (" (%s)" % name)

    def _set_axis(self):
        ''' Apply the axis limits when all four of them are provided '''
        if self._axlim is not None and None not in self._axlim:
            mpl.axis(self._axlim)

class Roc(PlotBase):
    ''' Handles the plotting of ROC'''
    def __init__(self, ctx, scores, evaluation, func_load):
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._title = self._title or 'ROC'
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or "1 - False Negative Rate"
        # custom defaults
        if self._axlim is None:
            self._axlim = [1e-4, 1.0, 1e-4, 1.0]

    def compute(self, idx, input_scores, input_names):
        ''' Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc`'''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        # cycle through the dash patterns so any number of systems works
        system_style = LINESTYLES[idx % len(LINESTYLES)]

        mpl.figure(1)
        if not self._eval:
            plot.roc_for_far(
                dev_neg, dev_pos,
                color=self._colors[idx], linestyle=system_style,
                label=self._label('development', dev_file, idx),
                **self._kwargs
            )
            return

        eval_neg, eval_pos = neg_list[1], pos_list[1]
        eval_file = input_names[1]

        linestyle = '-' if not self._split else system_style
        plot.roc_for_far(
            dev_neg, dev_pos,
            color=self._colors[idx], linestyle=linestyle,
            label=self._label('development', dev_file, idx),
            **self._kwargs
        )
        linestyle = '--'
        if self._split:
            mpl.figure(2)
            linestyle = system_style

        plot.roc_for_far(
            eval_neg, eval_pos,
            color=self._colors[idx], linestyle=linestyle,
            label=self._label('eval', eval_file, idx),
            **self._kwargs
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                # threshold chosen on dev at the requested FAR, applied to eval
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, thres_line)
                eval_fnmr = 1 - eval_fnmr
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

class Det(PlotBase):
    ''' Handles the plotting of DET '''
    def __init__(self, ctx, scores, evaluation, func_load):
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._title = self._title or 'DET'
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or 'False Negative Rate'
        if self._far_at is not None:
            # DET axes are in normal-deviate scale
            self._trans_far_val = [ppndf(float(k)) for k in self._far_at]
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

    def compute(self, idx, input_scores, input_names):
        ''' Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`'''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        eval_neg = eval_pos = eval_file = None
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]
        # cycle through the dash patterns so any number of systems works
        system_style = LINESTYLES[idx % len(LINESTYLES)]

        mpl.figure(1)
        if self._eval and eval_neg is not None:
            linestyle = '-' if not self._split else system_style
            plot.det(
                dev_neg, dev_pos, self._points, color=self._colors[idx],
                linestyle=linestyle,
                label=self._label('development', dev_file, idx),
                **self._kwargs
            )
            if self._split:
                mpl.figure(2)
            linestyle = '--' if not self._split else system_style
            plot.det(
                eval_neg, eval_pos, self._points, color=self._colors[idx],
                linestyle=linestyle,
                label=self._label('eval', eval_file, idx),
                **self._kwargs
            )
            if self._far_at is not None:
                from .. import farfrr
                for line in self._far_at:
                    # threshold chosen on dev at the requested FAR, applied
                    # to eval and mapped to the DET (ppndf) scale
                    thres_line = far_threshold(dev_neg, dev_pos, line)
                    eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, thres_line)
                    eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
                    mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                    self._eval_points[line].append((eval_fmr, eval_fnmr))
        else:
            plot.det(
                dev_neg, dev_pos, self._points, color=self._colors[idx],
                linestyle=system_style,
                label=self._label('development', dev_file, idx),
                **self._kwargs
            )

    def _set_axis(self):
        ''' Apply DET axis limits (in %), defaulting to [0.01, 99] on both '''
        if self._axlim is not None and None not in self._axlim:
            plot.det_axis(self._axlim)
        else:
            plot.det_axis([0.01, 99, 0.01, 99])

class Epc(PlotBase):
    ''' Handles the plotting of EPC '''
    def __init__(self, ctx, scores, evaluation, func_load):
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev and eval score files")
        self._title = self._title or 'EPC'
        self._x_label = self._x_label or r'$\alpha$'
        self._y_label = self._y_label or 'HTER (%)'
        self._eval = True  # always eval data with EPC
        self._split = False
        self._nb_figs = 1
        self._far_at = None

    def compute(self, idx, input_scores, input_names):
        ''' Plot EPC using :py:func:`bob.measure.plot.epc` '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        # `__init__` guarantees exactly two files (dev, eval) per system
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        eval_neg, eval_pos = neg_list[1], pos_list[1]
        dev_file, eval_file = input_names[0], input_names[1]

        plot.epc(
            dev_neg, dev_pos, eval_neg, eval_pos, self._points,
            color=self._colors[idx],
            # cycle through the dash patterns so any number of systems works
            linestyle=LINESTYLES[idx % len(LINESTYLES)],
            label=self._label('curve', dev_file + "_" + eval_file, idx),
            **self._kwargs
        )

class Hist(PlotBase):
    ''' Functional base class for histograms'''
    def __init__(self, ctx, scores, evaluation, func_load):
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        self._nbins = ctx.meta.get('nbins')
        self._thres = ctx.meta.get('thres')
        # without eval data the development histogram is always shown
        self._show_dev = ctx.meta.get('show_dev', not self._eval) \
            or not self._eval
        if self._thres is not None and len(self._thres) != self.n_systems:
            # a single threshold is shared by every system
            if len(self._thres) == 1:
                self._thres = self._thres * self.n_systems
            else:
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._criter = ctx.meta.get('criter')
        # user-provided labels take precedence over the defaults
        self._y_label = self._y_label or (
            'Dev. probability density' if self._eval else 'density'
        )
        self._x_label = self._x_label or ('Scores' if not self._eval else '')
        self._title_base = self._title or 'Scores'
        self._end_setup_plot = False

    def compute(self, idx, input_scores, input_names):
        ''' Draw histograms of negative and positive scores.

        One figure per system is created, saved to the PDF and closed; when
        both dev and eval histograms are shown they are stacked vertically.'''
        dev_neg, dev_pos, eval_neg, eval_pos, threshold = \
            self._get_neg_pos_thres(idx, input_scores, input_names)
        dev_file = input_names[0]
        eval_file = None if len(input_names) != 2 else input_names[1]

        fig = mpl.figure()
        if eval_neg is not None and self._show_dev:
            mpl.subplot(2, 1, 1)
        if self._show_dev:
            self._setup_hist(dev_neg, dev_pos)
            mpl.title(self._get_title(idx, dev_file, eval_file))
            mpl.ylabel(self._y_label)
            mpl.xlabel(self._x_label)
            if eval_neg is not None and self._show_dev:
                # x tick labels are shown on the eval subplot below
                ax = mpl.gca()
                ax.axes.get_xaxis().set_ticklabels([])
            # Setup lines, corresponding axis and legends
            self._lines(threshold, dev_neg, dev_pos)
            if self._eval:
                self._plot_legends()

        if eval_neg is not None:
            if self._show_dev:
                mpl.subplot(2, 1, 2)
            self._setup_hist(eval_neg, eval_pos)
            if not self._show_dev:
                mpl.title(self._get_title(idx, dev_file, eval_file))
            mpl.ylabel('Eval. probability density')
            mpl.xlabel(self._x_label)
            # Setup lines, corresponding axis and legends
            self._lines(threshold, eval_neg, eval_pos)
            if not self._show_dev:
                self._plot_legends()

        self._pdf_page.savefig(fig)
        # one new figure per system: release it once saved to avoid piling
        # up open matplotlib figures
        mpl.close(fig)

    def _get_title(self, idx, dev_file, eval_file):
        ''' Title of system `idx`: its user title if given, otherwise the base
        title optionally followed by the dev (/ eval) file names.'''
        title = self._titles[idx] if self._titles is not None else None
        if title is None:
            title = self._title_base if not self._print_fn else \
                    ('%s \n (%s)' % (
                        self._title_base,
                        str(dev_file) + (" / %s" % str(eval_file) if self._eval else "")
                    ))
        return title

    def _plot_legends(self):
        ''' Gather the legend entries of every axis of the current figure
        into a single legend.'''
        lines = []
        labels = []
        for ax in mpl.gcf().get_axes():
            li, la = ax.get_legend_handles_labels()
            lines += li
            labels += la
        if self._show_dev and self._eval:
            mpl.legend(
                lines, labels, loc='upper center', ncol=6,
                bbox_to_anchor=(0.5, -0.01), fontsize=6
            )
        else:
            mpl.legend(lines, labels,
                       loc='best', fancybox=True, framealpha=0.5)

    def _get_neg_pos_thres(self, idx, input_scores, input_names):
        ''' Return dev/eval negatives and positives (eval ones are None
        without eval data) and the threshold of system `idx`, either user
        provided or computed on dev with the selected criterion.'''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        eval_neg = eval_pos = None
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
        threshold = utils.get_thres(
            self._criter, dev_neg,
            dev_pos
        ) if self._thres is None else self._thres[idx]
        return (dev_neg, dev_pos, eval_neg, eval_pos, threshold)

    def _density_hist(self, scores, **kwargs):
        ''' Draw a normalized histogram of `scores` with `self._nbins` bins '''
        n, bins, patches = mpl.hist(
            scores, density=True, bins=self._nbins, **kwargs
        )
        return (n, bins, patches)

    def _lines(self, threshold, neg=None, pos=None, **kwargs):
        ''' Draw the vertical threshold line; `neg`/`pos` are unused here but
        available to derived classes.'''
        label = 'Threshold' if self._criter is None else self._criter.upper()
        kwargs.setdefault('color', 'C3')
        kwargs.setdefault('linestyle', '--')
        kwargs.setdefault('label', label)
        # plot a vertical threshold line
        mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs)

    def _setup_hist(self, neg, pos):
        ''' This function can be overwritten in derived classes'''
        self._density_hist(
            pos, label='Positives', alpha=0.5, color='C0', **self._kwargs
        )
        self._density_hist(
            neg, label='Negatives', alpha=0.5, color='C3', **self._kwargs
        )