'''Runs error analysis on score sets, outputs metrics and plots'''

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
5
import math
6
import sys
7
import os.path
8 9 10 11 12
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
13
from .. import (far_threshold, plot, utils, ppndf)
14 15 16 17 18 19 20 21 22 23 24 25

class MeasureBase(object):
    """Base class for metrics and plots.
    This abstract class define the framework to plot or compute metrics from a
    list of (positive, negative) scores tuples.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    """
    __metaclass__ = ABCMeta #for python 2.7 compatibility
26
    def __init__(self, ctx, scores, evaluation, func_load):
27 28 29 30 31 32 33
        """
        Parameters
        ----------
        ctx : :py:class:`dict`
            Click context dictionary.

        scores : :any:`list`:
34 35 36 37
            List of input files (e.g. dev-{1, 2, 3}, {dev,eval}-scores1
            {dev,eval}-scores2)
        eval : :py:class:`bool`
            True if eval data are used
38 39 40
        func_load : Function that is used to load the input files
        """
        self._scores = scores
41
        self._min_arg = 1 if 'min_arg' not in ctx.meta else ctx.meta['min_arg']
42 43
        self._ctx = ctx
        self.func_load = func_load
44
        self._legends = None if 'legends' not in ctx.meta else ctx.meta['legends']
45 46 47 48 49 50 51
        self._eval = evaluation
        self._min_arg = 1 if 'min_arg' not in ctx.meta else ctx.meta['min_arg']
        if len(scores) < 1 or len(scores) % self._min_arg != 0:
            raise click.BadParameter(
                'Number of argument must be a non-zero multiple of %d' % self._min_arg
            )
        self.n_systems = int(len(scores) / self._min_arg)
52 53
        if self._legends is not None and len(self._legends) != self.n_systems:
            raise click.BadParameter("Number of legends must be equal to the "
54
                                     "number of systems")
55 56 57 58 59 60 61 62 63 64

    def run(self):
        """ Generate outputs (e.g. metrics, files, pdf plots).
        This function calls abstract methods
        :func:`~bob.measure.script.figure.MeasureBase.init_process` (before
        loop), :py:func:`~bob.measure.script.figure.MeasureBase.compute`
        (in the loop iterating through the different
        systems) and :py:func:`~bob.measure.script.figure.MeasureBase.end_process`
        (after the loop).
        """
65
        # init matplotlib, log files, ...
66
        self.init_process()
67 68
        # iterates through the different systems and feed `compute`
        # with the dev (and eval) scores of each system
69 70
        # Note that more than one dev or eval scores score can be passed to
        # each system
71
        for idx in range(self.n_systems):
72 73
            # load scores for each system: get the corresponding arrays and 
            # base-name of files
74
            input_scores, input_names = self._load_files(
75 76 77 78 79 80
                # Scores are given as followed:
                # SysA-dev SysA-eval ... SysA-XX  SysB-dev SysB-eval ... SysB-XX
                # ------------------------------  ------------------------------
                #   First set of `self._min_arg`     Second set of input files
                #     input files starting at               for SysB
                #    index idx * self._min_arg
81
                self._scores[idx * self._min_arg:(idx + 1) * self._min_arg]
82 83
            )
            self.compute(idx, input_scores, input_names)
84
        # setup final configuration, plotting properties, ...
85 86
        self.end_process()

87
    # protected functions that need to be overwritten
88 89
    def init_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase.run`
        before iterating through the different systems.

        The base implementation does nothing; it should be reimplemented in
        derived classes that need a setup step."""
        pass

94
    # Main computations are done here in the subclasses
95
    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots from the given scores provided by
        :py:func:`~bob.measure.script.figure.MeasureBase.run`.
        Should be reimplemented in derived classes.

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        """
        pass

112
    # Things to do after the main iterative computations are done
113 114
    @abstractmethod
    def end_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase.run`
        after iterating through the different systems.
        Should be reimplemented in derived classes."""
        pass

120
    # common protected functions
121

122 123
    def _load_files(self, filepaths):
        ''' Load the input files and return the base names of the files
124 125 126

        Returns
        -------
127 128 129 130 131
            scores: :any:`list`:
                A list that contains the output of
                ``func_load`` for the given files
            basenames: :any:`list`:
                A list of basenames for the given files
132
        '''
133 134 135 136 137 138
        scores = []
        basenames = []
        for filename in filepaths:
            basenames.append(os.path.basename(filename).split(".")[0])
            scores.append(self.func_load(filename))
        return scores, basenames
139 140 141 142 143 144 145 146 147

class Metrics(MeasureBase):
    ''' Compute metrics from score files

    Attributes
    ----------
    log_file: str
        output stream
    '''
148 149
    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Read the metric options from ``ctx.meta`` and open the log stream.

        Options read: ``tablefmt``, ``criterion``, ``open_mode``, ``thres``,
        ``far_value`` and ``log``; each one defaults to ``None`` when absent.

        Parameters
        ----------
        ctx, scores, evaluation, func_load :
            Forwarded unchanged to the parent constructor.

        Raises
        ------
        click.BadParameter
            If several thresholds are given and their number differs from
            the number of systems.
        '''
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._tablefmt = meta.get('tablefmt')
        self._criterion = meta.get('criterion')
        self._open_mode = meta.get('open_mode')
        self._thres = meta.get('thres')
        if self._thres is not None:
            if len(self._thres) == 1:
                # a single threshold is shared by every system
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                # n_systems is an int: format it directly (calling len() on
                # it raised TypeError instead of the intended BadParameter)
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._far = meta.get('far_value')
        self._log = meta.get('log')
        self.log_file = sys.stdout
        if self._log is not None:
            # open() rejects a None mode; fall back to plain write mode when
            # no `open_mode` option was provided
            self.log_file = open(self._log, self._open_mode or 'w')

172
    def compute(self, idx, input_scores, input_names):
        ''' Compute metrics thresholds and tables (FtA, FMR, FNMR, FAR, FRR,
        HTER) for given system inputs and write them to the log stream.

        The threshold is either the user-provided one for this system or the
        one computed on the development set with the selected criterion. The
        same threshold is applied to the evaluation set when there is one.

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        '''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]
            eval_file = input_names[1]

        if self._thres is None:
            threshold = utils.get_thres(
                self._criterion, dev_neg, dev_pos, self._far
            )
        else:
            threshold = self._thres[idx]
        title = self._legends[idx] if self._legends is not None else None

        if self._thres is None:
            far_str = ''
            if self._criterion == 'far' and self._far is not None:
                far_str = str(self._far)
            click.echo("[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                       % (self._criterion.upper(),
                          far_str, title or dev_file,
                          threshold),
                       file=self.log_file)
        else:
            # NOTE(review): this branch prefers the file name over the legend
            # while the branch above prefers the legend; kept as is, confirm
            # which one is intended.
            click.echo("[Min. criterion: user provider] Threshold on "
                       "Development set `%s`: %e"
                       % (dev_file or title, threshold), file=self.log_file)

        # first header cell is the legend of the system, or empty when there
        # is none (`'' or title` evaluated to None in that case)
        headers = [title or '', 'Development %s' % dev_file]
        raws = [[name] for name in
                ('FtA', 'FMR', 'FNMR', 'FAR', 'FRR', 'HTER')]
        columns = [self._rates_column(dev_neg, dev_pos, dev_fta, threshold)]

        if self._eval:
            # statistics for the eval set use the threshold chosen a priori
            headers.append('Eval. %s' % eval_file)
            columns.append(
                self._rates_column(eval_neg, eval_pos, eval_fta, threshold)
            )

        for column in columns:
            for row, cell in zip(raws, column):
                row.append(cell)

        click.echo(tabulate(raws, headers, self._tablefmt), file=self.log_file)

    @staticmethod
    def _rates_column(neg, pos, fta, threshold):
        ''' Format the six table cells (FtA, FMR, FNMR, FAR, FRR, HTER) of
        one score set evaluated at ``threshold``.'''
        from .. import farfrr
        fmr, fnmr = farfrr(neg, pos, threshold)
        # FAR/FRR are FMR/FNMR corrected by the failure-to-acquire rate
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)
        hter = (far + frr) / 2.0

        n_impostors = neg.shape[0]
        n_false_accepts = int(round(fmr * n_impostors))
        n_clients = pos.shape[0]
        n_false_rejects = int(round(fnmr * n_clients))

        return [
            "%.1f%%" % (100 * fta),
            "%.1f%% (%d/%d)" % (100 * fmr, n_false_accepts, n_impostors),
            "%.1f%% (%d/%d)" % (100 * fnmr, n_false_rejects, n_clients),
            "%.1f%%" % (100 * far),
            "%.1f%%" % (100 * frr),
            "%.1f%%" % (100 * hter),
        ]

    def end_process(self):
        ''' Close the log file when one was opened for the ``log`` option'''
        if self._log is None:
            # no log option: nothing was opened by this object
            return
        self.log_file.close()

class PlotBase(MeasureBase):
    ''' Base class for plots. Regroup several options and code
    shared by the different plots
    '''
264 265
    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Read the plotting options shared by all plots from ``ctx.meta``.

        Options read (default when absent): ``output`` (None), ``points``
        (100), ``split`` (None), ``axlim`` (None), ``min_far_value``,
        ``clayout`` (None), ``lines_at`` (None), ``show_fn`` (True),
        ``x_rotation`` (None), ``style``, ``line_linestyles`` (False),
        ``title``, ``x_label`` and ``y_label`` (None).

        Parameters
        ----------
        ctx, scores, evaluation, func_load :
            Forwarded unchanged to the parent constructor.
        '''
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._output = meta.get('output')
        self._points = meta.get('points', 100)
        self._split = meta.get('split')
        self._axlim = meta.get('axlim')

        # power of ten of the smallest FAR: taken from `min_far_value` when
        # given, otherwise from the first axis limit (0 when that limit is 0)
        self._min_dig = None
        if 'min_far_value' in meta:
            self._min_dig = int(math.log10(meta['min_far_value']))
        elif self._axlim is not None:
            lower = self._axlim[0]
            self._min_dig = int(math.log10(lower) if lower != 0 else 0)

        self._clayout = meta.get('clayout')
        self._far_at = meta.get('lines_at')
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            # one list of points per requested line
            self._eval_points = dict((line, []) for line in self._far_at)
            self._lines_val = []
        self._print_fn = meta.get('show_fn', True)
        self._x_rotation = meta.get('x_rotation')

        # the style is applied before colors and line styles are requested
        if 'style' in meta:
            mpl.style.use(meta['style'])
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._line_linestyles = meta.get('line_linestyles', False)
        self._linestyles = utils.get_linestyles(self.n_systems, self._line_linestyles)

        self._states = ['Development', 'Evaluation']
        self._title = meta.get('title')
        self._x_label = meta.get('x_label')
        self._y_label = meta.get('y_label')
        self._grid_color = 'silver'
        self._pdf_page = None
        self._end_setup_plot = True

    def init_process(self):
        ''' Select the pdf backend if needed, get the output pdf and create
        cleared figures.

        The ``PdfPages`` object stored in the click context is reused when
        there is one; otherwise a new one is opened on the output path.
        '''
        if not hasattr(matplotlib, 'backends'):
            matplotlib.use('pdf')

        meta = self._ctx.meta
        if 'PdfPages' in meta:
            self._pdf_page = meta['PdfPages']
        else:
            self._pdf_page = PdfPages(self._output)

        # same (optional) size for every figure
        figsize = meta.get('figsize')
        for number in range(1, self._nb_figs + 1):
            fig = mpl.figure(number, figsize=figsize)
            fig.set_constrained_layout(self._clayout)
            fig.clear()
319 320

    def end_process(self):
        ''' Set title, legend, axis labels, grid colors, save figures, draw
        lines and close pdf if needed '''
        meta = self._ctx.meta

        # draw vertical lines
        if self._far_at is not None:
            for (line, line_trans) in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans], [-100.0, 100.], "--",
                    color='black'
                )
                if self._eval and self._split:
                    # on the evaluation figure, join the points recorded for
                    # this line, ordered by their x value (stable sort)
                    mpl.figure(2)
                    points = sorted(
                        self._eval_points[line], key=lambda point: point[0]
                    )
                    mpl.plot(
                        [x for x, _ in points],
                        [y for _, y in points], '--',
                        color='black'
                    )

        # only for plots
        if self._end_setup_plot:
            # a missing title is treated like an empty one instead of
            # failing on `None += str`
            base_title = self._title or ''
            # a blank title means "no title at all", suffix included
            show_title = bool(base_title.replace(' ', ''))
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = base_title
                if not self._eval:
                    title += (" (%s)" % self._states[0])
                elif self._split:
                    title += (" (%s)" % self._states[i])
                mpl.title(title if show_title else '')
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                mpl.legend(loc='best')
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        # A pdf that is not in the context was opened by `init_process` and
        # nobody else can close it. A pdf coming from the context is closed
        # unless `closef` is false (do not want to close PDF when running
        # evaluate).
        if 'PdfPages' not in meta or meta.get('closef', True):
            self._pdf_page.close()

    #common protected functions

    def _label(self, base, name, idx):
        ''' Return the legend label of a curve.

        The user legend of system ``idx`` is used when there is one.
        Otherwise the label is ``base`` followed by ``(name)``, with the
        1-based system number inserted when several systems are plotted.
        '''
        legends = self._legends
        if legends is not None and idx < len(legends):
            return legends[idx]
        if self.n_systems > 1:
            suffix = " %d (%s)" % (idx + 1, name)
        else:
            suffix = " (%s)" % name
        return base + suffix

375
    def _set_axis(self):
        ''' Apply the axis limits to the current plot, but only when all of
        them are defined.'''
        limits = self._axlim
        if limits is None or None in limits:
            return
        mpl.axis(limits)
378

379
class Roc(PlotBase):
    ''' Handles the plotting of ROC'''

    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Set the ROC defaults for title, axis labels and axis limits.'''
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._title = self._title or 'ROC'
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or "1 - False Negative Rate"
        # custom defaults
        if self._axlim is None:
            self._axlim = [1e-4, 1.0, 0, 1.0]

        if self._min_dig is not None:
            # work on a private copy: the limits may be the very list stored
            # in the click context, which must not be modified from here
            self._axlim = list(self._axlim)
            self._axlim[0] = math.pow(10, self._min_dig)

    def compute(self, idx, input_scores, input_names):
        ''' Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc_for_far`

        The development curve always goes to figure 1. The evaluation curve
        goes to figure 2 when plots are split, otherwise it is dashed and
        drawn on figure 1. When lines are requested, the evaluation point at
        the development threshold of each line is marked and recorded.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        # the development curve is drawn the same way with or without eval
        mpl.figure(1)
        plot.roc_for_far(
            dev_neg, dev_pos,
            far_values=plot.log_values(self._min_dig or -4),
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label('development', dev_file, idx)
        )
        if not self._eval:
            return

        if self._split:
            mpl.figure(2)
        linestyle = '--' if not self._split else self._linestyles[idx]
        plot.roc_for_far(
            eval_neg, eval_pos, linestyle=linestyle,
            far_values=plot.log_values(self._min_dig or -4),
            color=self._colors[idx],
            label=self._label('eval', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                # threshold chosen on dev, rates measured on eval
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, thres_line)
                # the y axis of the plot is 1 - FNMR
                eval_fnmr = 1 - eval_fnmr
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

class Det(PlotBase):
    ''' Handles the plotting of DET '''

    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Set the DET defaults for title, axis labels, tick rotation and
        axis limits, and transform the requested line positions with
        ``ppndf``.'''
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._title = self._title or 'DET'
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or 'False Negative Rate'
        if self._far_at is not None:
            self._trans_far_val = [ppndf(float(k)) for k in self._far_at]
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

        if self._axlim is None:
            self._axlim = [0.01, 99, 0.01, 99]

        if self._min_dig is not None:
            # work on a private copy: the limits may be the very list stored
            # in the click context, which must not be modified from here
            self._axlim = list(self._axlim)
            self._axlim[0] = math.pow(10, self._min_dig) * 100

    def compute(self, idx, input_scores, input_names):
        ''' Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`

        The development curve always goes to figure 1. The evaluation curve
        goes to figure 2 when plots are split, otherwise it is dashed and
        drawn on figure 1. When lines are requested, the evaluation point at
        the development threshold of each line is marked and recorded.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        eval_neg = eval_pos = eval_file = None
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        # the development curve is drawn the same way with or without eval
        mpl.figure(1)
        plot.det(
            dev_neg, dev_pos, self._points, color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label('development', dev_file, idx)
        )
        if eval_neg is None:
            # no evaluation data for this system
            return

        if self._split:
            mpl.figure(2)
        linestyle = '--' if not self._split else self._linestyles[idx]
        plot.det(
            eval_neg, eval_pos, self._points, color=self._colors[idx],
            linestyle=linestyle,
            label=self._label('eval', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                # threshold chosen on dev, rates measured on eval
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, thres_line)
                eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

    def _set_axis(self):
        ''' Apply the axis limits with :py:func:`bob.measure.plot.det_axis`'''
        plot.det_axis(self._axlim)
498 499 500

class Epc(PlotBase):
    ''' Handles the plotting of EPC '''

    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Set the EPC defaults. EPC needs exactly two files per system and
        is always drawn on a single figure with evaluation data, without
        extra lines.

        Raises
        ------
        click.UsageError
            If the number of files per system is not 2.
        '''
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev and eval score files")
        # always eval data with EPC, on one single figure
        self._eval = True
        self._split = False
        self._nb_figs = 1
        self._far_at = None
        self._title = self._title or 'EPC'
        self._x_label = self._x_label or r'$\alpha$'
        self._y_label = self._y_label or 'HTER (%)'

    def compute(self, idx, input_scores, input_names):
        ''' Draw the EPC of one system with :py:func:`bob.measure.plot.epc` '''
        negatives, positives, ftas = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, _ = negatives[0], positives[0], ftas[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos, _ = negatives[1], positives[1], ftas[1]
            eval_file = input_names[1]

        # the curve is named after both files of the system
        curve_label = self._label('curve', dev_file + "_" + eval_file, idx)
        plot.epc(
            dev_neg, dev_pos, eval_neg, eval_pos, self._points,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=curve_label
        )

class Hist(PlotBase):
531
    ''' Functional base class for histograms'''
532 533
    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Read the histogram options from ``ctx.meta`` and set the
        histogram defaults.

        Options read (default when absent): ``n_bins`` (empty list),
        ``thres`` (None), ``show_dev`` (False) and ``criterion`` (None).
        The development histogram is always shown when there is no
        evaluation data.

        Raises
        ------
        click.BadParameter
            If several thresholds are given and their number differs from
            the number of systems.
        '''
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._nbins = meta.get('n_bins', [])
        self._thres = meta.get('thres')
        self._show_dev = meta.get('show_dev', False) or not self._eval
        if self._thres is not None and len(self._thres) != self.n_systems:
            if len(self._thres) != 1:
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
            # a single threshold is shared by every system
            self._thres = self._thres * self.n_systems
        self._criterion = meta.get('criterion')
        self._title_base = 'Scores'
        # keep the labels given through the `x_label` / `y_label` options,
        # like the other plots do; the texts below are only defaults
        self._y_label = self._y_label or 'Probability Density'
        self._x_label = self._x_label or 'Scores values'
        # histograms save their own figures in `compute`
        self._end_setup_plot = False

553
    def compute(self, idx, input_scores, input_names):
        ''' Draw histograms of negative and positive scores.

        One new figure is created per system and saved to the pdf. It holds
        the development histogram (when shown) and, when evaluation scores
        exist, the evaluation histogram; with both, they are stacked in two
        subplots. The threshold is drawn on every histogram.
        '''
        dev_neg, dev_pos, eval_neg, eval_pos, threshold = \
            self._get_neg_pos_thres(idx, input_scores, input_names)
        dev_file = input_names[0]
        eval_file = None if len(input_names) != 2 else input_names[1]

        has_eval = eval_neg is not None
        two_panels = has_eval and self._show_dev
        criterion = '' if self._criterion is None else self._criterion.upper()

        fig = mpl.figure()
        if self._show_dev:
            if two_panels:
                mpl.subplot(2, 1, 1)
            self._setup_hist(dev_neg, dev_pos)
            mpl.title(self._get_title(idx, dev_file, eval_file))
            mpl.ylabel(self._y_label)
            if not self._eval:
                mpl.xlabel(self._x_label)
            if two_panels:
                # the x tick labels are only kept on the bottom panel
                ax = mpl.gca()
                ax.axes.get_xaxis().set_ticklabels([])
            # Setup lines, corresponding axis and legends
            self._lines(threshold, "%s threshold" % criterion,
                        dev_neg, dev_pos)
            self._plot_legends()

        if has_eval:
            if self._show_dev:
                mpl.subplot(2, 1, 2)
            self._setup_hist(eval_neg, eval_pos)
            if not self._show_dev:
                mpl.title(self._get_title(idx, dev_file, eval_file))
            # same y label as the development panel (was a separate literal)
            mpl.ylabel(self._y_label)
            mpl.xlabel(self._x_label)
            # Setup lines, corresponding axis and legends
            self._lines(threshold, "%s threshold (dev)" % criterion,
                        eval_neg, eval_pos)
            if not self._show_dev:
                self._plot_legends()

        self._pdf_page.savefig(fig)
596

597
    def _get_title(self, idx, dev_file, eval_file):
        ''' Return the title of the histogram of system ``idx``.

        A title option made only of spaces disables the title. Otherwise the
        first non-empty value among the legend of the system, the title
        option and the base title is used.
        '''
        legend = self._legends[idx] if self._legends is not None else None
        if self._title is not None and not self._title.replace(' ', ''):
            return ''
        return legend or self._title or self._title_base or ''
602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618

    def _plot_legends(self):
        ''' Draw one legend gathering the entries of every axes of the
        current figure.'''
        handles, labels = [], []
        for ax in mpl.gcf().get_axes():
            ax_handles, ax_labels = ax.get_legend_handles_labels()
            handles.extend(ax_handles)
            labels.extend(ax_labels)

        if not (self._show_dev and self._eval):
            mpl.legend(handles, labels,
                       loc='best', fancybox=True, framealpha=0.5)
            return
        # dev and eval shown together: small legend anchored at the bottom
        mpl.legend(
            handles, labels, loc='upper center', ncol=6,
            bbox_to_anchor=(0.5, -0.01), fontsize=6
        )

619
    def _get_neg_pos_thres(self, idx, input_scores, input_names):
        ''' Split the scores of one system into dev and eval lists and get
        its threshold.

        Even positions of the loaded files are development sets and odd
        positions are evaluation sets. The threshold is the user-provided
        one for system ``idx``, or else the one computed with the criterion
        on the first development set.

        Returns
        -------
            dev_neg, dev_pos, eval_neg, eval_pos, threshold:
                eval lists are ``None`` when no evaluation data is used
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        n_files = len(neg_list)

        # can have several files for one system
        dev_slots = range(0, n_files, 2)
        dev_neg = [neg_list[slot] for slot in dev_slots]
        dev_pos = [pos_list[slot] for slot in dev_slots]

        if self._eval:
            eval_slots = range(1, n_files, 2)
            eval_neg = [neg_list[slot] for slot in eval_slots]
            eval_pos = [pos_list[slot] for slot in eval_slots]
        else:
            eval_neg = eval_pos = None

        if self._thres is None:
            threshold = utils.get_thres(
                self._criterion, dev_neg[0], dev_pos[0]
            )
        else:
            threshold = self._thres[idx]
        return dev_neg, dev_pos, eval_neg, eval_pos, threshold
634

635
    def _density_hist(self, scores, n, **kwargs):
        ''' Draw a density histogram of ``scores`` and return the result of
        ``hist`` as a 3-tuple.

        The number of bins is the ``n``-th configured value, or ``'auto'``
        when fewer than ``n + 1`` values were configured. Extra keyword
        arguments are forwarded to ``hist``.
        '''
        if n < len(self._nbins):
            bins = self._nbins[n]
        else:
            bins = 'auto'
        counts, edges, patches = mpl.hist(
            scores, density=True, bins=bins, **kwargs
        )
        return (counts, edges, patches)

643 644
    def _lines(self, threshold, label=None, neg=None, pos=None, **kwargs):
        ''' Plot a vertical threshold line over the full height of the
        current axes.

        ``neg`` and ``pos`` are not used here. Keyword arguments override
        the default color, line style and label of the line.
        '''
        style = {
            'color': 'C3',
            'linestyle': '--',
            'label': label or 'Threshold',
        }
        style.update(kwargs)
        mpl.axvline(x=threshold, ymin=0, ymax=1, **style)
650 651 652 653

    def _setup_hist(self, neg, pos):
        ''' Draw the histograms of the first negative and first positive
        score sets. This function can be overwritten in derived classes'''
        series = (
            (neg, 'Negatives', 'C3'),
            (pos, 'Positives', 'C0'),
        )
        for position, (score_sets, label, color) in enumerate(series):
            self._density_hist(
                score_sets[0], n=position,
                label=label, alpha=0.5, color=color
            )
661