# figure.py
'''Runs error analysis on score sets, outputs metrics and plots'''

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
import sys
import ntpath
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
from .. import plot
from .. import utils

# Dash patterns in matplotlib's ``(offset, (on, off, ...))`` tuple form, used
# to tell systems apart when their curves share a figure. The name paired with
# each pattern only documents how it looks; ``LINESTYLES`` keeps the patterns,
# in the same order.
_NAMED_LINESTYLES = (
    ('solid', (0, ())),
    ('dashed', (0, (4, 4))),
    ('dotted', (0, (1, 5))),
    ('dashdotted', (0, (3, 5, 1, 5))),
    ('dashdotdotted', (0, (3, 5, 1, 5, 1, 5))),
    ('densely dashed', (0, (5, 1))),
    ('densely dotted', (0, (1, 1))),
    ('densely dashdotted', (0, (3, 1, 1, 1))),
    ('densely dashdotdotted', (0, (3, 1, 1, 1, 1, 1))),
    ('loosely dashed', (0, (5, 10))),
    ('loosely dashdotted', (0, (3, 10, 1, 10))),
    ('loosely dashdotdotted', (0, (3, 10, 1, 10, 1, 10))),
    ('loosely dotted', (0, (1, 10))),
)

LINESTYLES = [dash for _label, dash in _NAMED_LINESTYLES]

class MeasureBase(object):
    """Base class for metrics and plots.

    This abstract class defines the framework used to plot or compute metrics
    from a list of (negative, positive) score tuples.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    dev_names, eval_names:
        One entry per required argument: the list of dev/eval file names
        (basename up to the first '.'), or ``None`` when no such files were
        given for that argument.
    dev_scores, eval_scores:
        One entry per required argument: the output of ``func_load`` on the
        corresponding dev/eval file paths.
    n_sytem:
        Number of systems, i.e. number of dev files of the first argument.
    """
    __metaclass__ = ABCMeta  # for python 2.7 compatibility

    def __init__(self, ctx, scores, eval, func_load):
        """
        Parameters
        ----------
        ctx : :py:class:`click.Context`
            Click context; options are read from its ``meta`` dictionary.
        scores : :any:`list`
            List of input files (e.g. dev-{1, 2, 3}, {dev,eval}-scores1
            {dev,eval}-scores2). Used as dev files for an argument when
            ``ctx.meta['dev_scores_<arg>']`` is not set.
        eval : :py:class:`bool`
            True if eval data are used
        func_load :
            Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If titles are given and their number differs from the number of
            systems.
        """
        self._scores = scores
        self._min_arg = ctx.meta.get('min_arg', 1)
        self._ctx = ctx
        self.func_load = func_load
        self.dev_names, self.eval_names, self.dev_scores, self.eval_scores = \
                self._load_files()
        # at least one set of dev scores is always present; the (misspelled)
        # attribute name is kept for backward compatibility
        self.n_sytem = len(self.dev_names[0])
        self._titles = ctx.meta.get('titles')
        if self._titles is not None and len(self._titles) != self.n_sytem:
            raise click.BadParameter("Number of titles must be equal to the "
                                     "number of systems")
        self._eval = eval

    def run(self):
        """ Generate outputs (e.g. metrics, files, pdf plots).

        This function calls
        :func:`~bob.measure.script.figure.MeasureBase.init_process` (before
        the loop), :py:func:`~bob.measure.script.figure.MeasureBase.compute`
        (once per system) and
        :py:func:`~bob.measure.script.figure.MeasureBase.end_process`
        (after the loop).
        """
        # init matplotlib, log files, ...
        self.init_process()
        # iterate through the systems and feed `compute` with the dev (and
        # eval) scores of each system. One score set per required argument
        # is gathered for every system.
        for idx in range(self.n_sytem):
            dev_score = []
            eval_score = []
            dev_file = []
            eval_file = []
            for arg in range(self._min_arg):
                dev_score.append(self.dev_scores[arg][idx])
                dev_file.append(self.dev_names[arg][idx])
                eval_score.append(self.eval_scores[arg][idx]
                                  if self.eval_scores[arg] is not None
                                  else None)
                eval_file.append(self.eval_names[arg][idx]
                                 if self.eval_names[arg] is not None
                                 else None)
            if self._min_arg == 1:
                # most measures only take one argument: do not wrap it in a
                # list of length one
                self.compute(idx, dev_score[0], dev_file[0], eval_score[0],
                             eval_file[0])
            else:
                self.compute(idx, dev_score, dev_file, eval_score, eval_file)
        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        before iterating through the different systems.
        Should be reimplemented in derived classes; does nothing here."""
        pass

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, dev_score, dev_file=None,
                eval_score=None, eval_file=None):
        """Compute metrics or plots from the given scores provided by
        :py:func:`~bob.measure.script.figure.MeasureBase.run`.
        Should be reimplemented in derived classes.

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        dev_score:
            Development scores. Can be a tuple (neg, pos) of
            :py:class:`numpy.ndarray` (e.g.
            :py:func:`~bob.measure.script.figure.Roc.compute`) or
            a :any:`list` of tuples of :py:class:`numpy.ndarray` (e.g. cmc)
        dev_file : str
            name of the dev file without extension
        eval_score:
            eval scores. Can be a tuple (neg, pos) of
            :py:class:`numpy.ndarray` (e.g.
            :py:func:`~bob.measure.script.figure.Roc.compute`) or
            a :any:`list` of tuples of :py:class:`numpy.ndarray` (e.g. cmc)
        eval_file : str
            name of the eval file without extension
        """
        pass

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        after iterating through the different systems.
        Should be reimplemented in derived classes."""
        pass

    # common protected functions

    def _load_files(self):
        ''' Load the input files.

        For every required argument ``arg`` (``0 .. min_arg - 1``) the dev
        paths are ``ctx.meta['dev_scores_<arg>']`` (falling back to the
        ``scores`` given to the constructor) and the eval paths are
        ``ctx.meta['eval_scores_<arg>']`` (``None`` when absent).

        Returns
        -------
        tuple
            ``(dev_files, eval_files, dev_scores, eval_scores)``, each a
            :any:`list` with one entry per required argument. ``*_files``
            hold the file names (basename up to the first '.') or ``None``;
            ``*_scores`` hold the output of ``func_load`` on the paths.
        '''

        def _extract_file_names(filenames):
            # basename of every path, cut at its first '.'; None passes through
            if filenames is None:
                return None
            return [ntpath.split(path)[1].split(".")[0] for path in filenames]

        dev_scores = []
        eval_scores = []
        dev_files = []
        eval_files = []
        meta = self._ctx.meta
        for arg in range(self._min_arg):
            dev_paths = meta.get('dev_scores_%d' % arg, self._scores)
            eval_paths = meta.get('eval_scores_%d' % arg)
            dev_files.append(_extract_file_names(dev_paths))
            eval_files.append(_extract_file_names(eval_paths))
            dev_scores.append(self.func_load(dev_paths))
            eval_scores.append(self.func_load(eval_paths))
        return (dev_files, eval_files, dev_scores, eval_scores)

    def _process_scores(self, dev_score, eval_score):
        '''Split dev and eval scores into negatives, positives and FtA.

        ``utils.get_fta`` is unpacked as ``((neg, pos), fta)``. Eval values
        are only computed when eval data are enabled and present; anything
        not computed is returned as ``None``.

        Returns
        -------
        tuple
            ``(dev_neg, dev_pos, dev_fta, eval_neg, eval_pos, eval_fta)``

        Raises
        ------
        click.UsageError
            If the negatives of a loaded score set are ``None``.
        '''
        dev_neg = dev_pos = dev_fta = eval_neg = eval_pos = eval_fta = None
        if dev_score[0] is not None:
            (dev_neg, dev_pos), dev_fta = utils.get_fta(dev_score)
            if dev_neg is None:
                raise click.UsageError("While loading dev-score file")

        if self._eval and eval_score is not None and eval_score[0] is not None:
            (eval_neg, eval_pos), eval_fta = utils.get_fta(eval_score)
            if eval_neg is None:
                raise click.UsageError("While loading eval-score file")

        return (dev_neg, dev_pos, dev_fta, eval_neg, eval_pos, eval_fta)


class Metrics(MeasureBase):
    ''' Compute metrics from score files

    Attributes
    ----------
    log_file:
        Output stream: ``sys.stdout``, or the file opened from
        ``ctx.meta['log']``.
    '''

    # row labels of the printed table, in the order of ``_rate_strings``
    _ROW_NAMES = ('FtA', 'FMR', 'FNMR', 'FAR', 'FRR', 'HTER')

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._tablefmt = meta.get('tablefmt')
        self._criter = meta.get('criter')
        self._open_mode = meta.get('open_mode')
        self._thres = meta.get('thres')
        if self._thres is not None:
            # one threshold per system; a single value applies to all systems
            # (n_sytem counts systems, dev_names counts arguments)
            if len(self._thres) == 1:
                self._thres = self._thres * self.n_sytem
            elif len(self._thres) != self.n_sytem:
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_sytem
                )
        self._far = meta.get('far_value')
        self._log = meta.get('log')
        self.log_file = sys.stdout
        if self._log is not None:
            # fall back to write mode when no open mode was given
            self.log_file = open(self._log, self._open_mode or 'w')

    @staticmethod
    def _rate_strings(fmr, fnmr, fta, n_neg, n_pos):
        '''Format the FtA, FMR, FNMR, FAR, FRR and HTER table cells.

        FAR and FRR fold the failure-to-acquire rate ``fta`` into the
        matching rates; FMR/FNMR also show the rounded error counts out of
        ``n_neg`` negatives / ``n_pos`` positives.
        '''
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)
        hter = (far + frr) / 2.0
        n_false_match = int(round(fmr * n_neg))
        n_false_non_match = int(round(fnmr * n_pos))
        return [
            "%.1f%%" % (100 * fta),
            "%.1f%% (%d/%d)" % (100 * fmr, n_false_match, n_neg),
            "%.1f%% (%d/%d)" % (100 * fnmr, n_false_non_match, n_pos),
            "%.1f%%" % (100 * far),
            "%.1f%%" % (100 * frr),
            "%.1f%%" % (100 * hter),
        ]

    def compute(self, idx, dev_score, dev_file=None,
                eval_score=None, eval_file=None):
        ''' Compute metrics thresholds and tables (FAR, FMR, FNMR, HTER) for
        given system inputs and print them to ``log_file``.

        The threshold is either user provided (``thres``) or computed on the
        development scores with the selected criterion; it is then applied
        unchanged to the evaluation scores.'''
        dev_neg, dev_pos, dev_fta, eval_neg, eval_pos, eval_fta = \
                self._process_scores(dev_score, eval_score)
        title = self._titles[idx] if self._titles is not None else None
        system_name = title or dev_file

        if self._thres is None:
            threshold = utils.get_thres(
                self._criter, dev_neg, dev_pos, self._far
            )
            far_str = ''
            if self._criter == 'far' and self._far is not None:
                far_str = str(self._far)
            click.echo("[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                       % (self._criter.upper(), far_str, system_name,
                          threshold),
                       file=self.log_file)
        else:
            threshold = self._thres[idx]
            click.echo("[Min. criterion: user provider] Threshold on "
                       "Development set `%s`: %e"
                       % (system_name, threshold), file=self.log_file)

        from .. import farfrr
        headers = [title or '', 'Development %s' % dev_file]
        dev_fmr, dev_fnmr = farfrr(dev_neg, dev_pos, threshold)
        columns = [self._rate_strings(dev_fmr, dev_fnmr, dev_fta,
                                      dev_neg.shape[0], dev_pos.shape[0])]

        if self._eval and eval_neg is not None:
            # statistics for the eval set use the a priori (dev) threshold
            eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, threshold)
            headers.append('Eval. %s' % eval_file)
            columns.append(self._rate_strings(eval_fmr, eval_fnmr, eval_fta,
                                              eval_neg.shape[0],
                                              eval_pos.shape[0]))

        rows = [[row_name] + list(cells)
                for row_name, cells in zip(self._ROW_NAMES, zip(*columns))]
        click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)

    def end_process(self):
        ''' Close the log file if one was opened'''
        if self._log is not None:
            self.log_file.close()

class PlotBase(MeasureBase):
    ''' Base class for plots. Regroups several options and code
    shared by the different plots
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._output = meta.get('output')
        self._points = meta.get('points', 100)
        self._split = meta.get('split')
        self._axlim = meta.get('axlim')
        self._clayout = meta.get('clayout')
        self._print_fn = meta.get('show_fn', True)
        self._x_rotation = meta.get('x_rotation')
        if 'style' in meta:
            mpl.style.use(meta['style'])
        self._nb_figs = 2 if self._eval and self._split else 1
        # size these by the number of systems: dev_scores has one entry per
        # required argument, not per system
        self._multi_plots = self.n_sytem > 1
        self._colors = utils.get_colors(self.n_sytem)
        self._states = ['Development', 'Evaluation']
        self._title = meta.get('title')
        self._x_label = meta.get('x_label')
        self._y_label = meta.get('y_label')
        self._grid_color = 'silver'
        self._pdf_page = None
        self._end_setup_plot = True
        # extra keyword arguments forwarded to the plotting functions
        self._kwargs = {}

    def init_process(self):
        ''' Open the pdf and create (and clear) the figures.

        A ``PdfPages`` object shared through ``ctx.meta['PdfPages']`` is
        reused; otherwise one is opened on the ``output`` path. Figures
        ``1 .. nb_figs`` get the optional ``figsize`` and constrained layout
        setting.'''
        if not hasattr(matplotlib, 'backends'):
            matplotlib.use('pdf')

        meta = self._ctx.meta
        self._pdf_page = meta['PdfPages'] if 'PdfPages' in meta \
                else PdfPages(self._output)

        figsize = meta.get('figsize')
        for i in range(self._nb_figs):
            fig = mpl.figure(i + 1, figsize=figsize)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        ''' Set title, legend, axis labels, grid colors, save figures and
        close the pdf if needed '''
        # only for plots
        if self._end_setup_plot:
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = self._title or ''
                if not self._eval:
                    title += " (%s)" % self._states[0]
                elif self._split:
                    title += " (%s)" % self._states[i]
                mpl.title(title)
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                mpl.legend(loc='best')
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        meta = self._ctx.meta
        # A pdf opened here is always closed so the file gets finalized.
        # A pdf shared through ctx.meta stays open when 'closef' is false
        # (e.g. when running evaluate).
        if 'PdfPages' not in meta or meta.get('closef', True):
            self._pdf_page.close()

    # common protected functions

    def _label(self, base, name, idx):
        ''' Legend label of system ``idx``: the user title when given,
        otherwise ``base`` plus the file name (and the 1-based system number
        when several systems are plotted).'''
        if self._titles is not None and len(self._titles) > idx:
            return self._titles[idx]
        if self._multi_plots:
            return base + (" %d (%s)" % (idx + 1, name))
        return base + (" (%s)" % name)

    def _set_axis(self):
        ''' Apply the axis limits when all four of them are given.'''
        if self._axlim is not None and None not in self._axlim:
            mpl.axis(self._axlim)

class Roc(PlotBase):
    ''' Handles the plotting of ROC'''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._semilogx = ctx.meta.get('semilogx', True)
        # false positive rates at which vertical lines are drawn
        self._far_at = ctx.meta.get('lines_at')
        self._title = self._title or 'ROC'
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or (
            "1 - False Negative Rate" if self._semilogx else "False Negative Rate"
        )
        # custom defaults
        if self._axlim is None:
            self._axlim = [1e-4, 1.0, 1e-4, 1.0]

        if self._far_at is not None:
            # per requested FPR: the eval operating point of every system
            self._eval_points = {line: [] for line in self._far_at}

    def compute(self, idx, dev_score, dev_file=None,
                eval_score=None, eval_file=None):
        ''' Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc`'''
        dev_neg, dev_pos, _, eval_neg, eval_pos, _ = \
                self._process_scores(dev_score, eval_score)
        color = self._colors[idx]
        # cycle through the available patterns (avoids IndexError)
        system_style = LINESTYLES[idx % len(LINESTYLES)]
        mpl.figure(1)
        if self._eval and eval_neg is not None:
            plot.roc(
                dev_neg, dev_pos, self._points, self._semilogx,
                color=color,
                linestyle=system_style if self._split else '-',
                label=self._label('development', dev_file, idx),
                **self._kwargs
            )
            if self._split:
                mpl.figure(2)
            plot.roc(
                eval_neg, eval_pos, self._points, self._semilogx,
                color=color,
                linestyle=system_style if self._split else '--',
                label=self._label('eval', eval_file, idx),
                **self._kwargs
            )
            if self._far_at is not None:
                from .. import far_threshold, farfrr
                for line in self._far_at:
                    # `line` is a false positive rate: find the dev threshold
                    # reaching it and report the eval rates at that threshold
                    thres_line = far_threshold(dev_neg, dev_pos, line)
                    eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, thres_line)
                    if self._semilogx:
                        eval_fnmr = 1 - eval_fnmr
                    mpl.scatter(eval_fmr, eval_fnmr, c=color, s=30)
                    self._eval_points[line].append((eval_fmr, eval_fnmr))
        else:
            plot.roc(
                dev_neg, dev_pos, self._points, self._semilogx,
                color=color, linestyle=system_style,
                label=self._label('development', dev_file, idx),
                **self._kwargs
            )

    def end_process(self):
        ''' Draw vertical lines on the dev plot at the given FPRs and, when
        the eval plot is separate, join the corresponding eval points of all
        systems '''
        if self._far_at is not None:
            for line in self._far_at:
                mpl.figure(1)
                mpl.plot([line, line], [0., 1.], "--", color='black')
                if self._eval and self._split:
                    mpl.figure(2)
                    # stable sort by FPR so the joining line is monotonic in x
                    points = sorted(self._eval_points[line],
                                    key=lambda point: point[0])
                    mpl.plot([x for x, _ in points],
                             [y for _, y in points], '--',
                             color='black')
        super(Roc, self).end_process()

class Det(PlotBase):
    ''' Handles the plotting of DET '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._title = self._title or 'DET'
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or 'False Negative Rate'
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

    def compute(self, idx, dev_score, dev_file=None,
                eval_score=None, eval_file=None):
        ''' Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`'''
        dev_neg, dev_pos, _, eval_neg, eval_pos, _ = \
                self._process_scores(dev_score, eval_score)
        color = self._colors[idx]
        # cycle through the available patterns (avoids IndexError)
        system_style = LINESTYLES[idx % len(LINESTYLES)]
        mpl.figure(1)
        if self._eval and eval_neg is not None:
            plot.det(
                dev_neg, dev_pos, self._points, color=color,
                linestyle=system_style if self._split else '-',
                label=self._label('development', dev_file, idx),
                **self._kwargs
            )
            if self._split:
                mpl.figure(2)
            plot.det(
                eval_neg, eval_pos, self._points, color=color,
                linestyle=system_style if self._split else '--',
                label=self._label('eval', eval_file, idx),
                **self._kwargs
            )
        else:
            plot.det(
                dev_neg, dev_pos, self._points, color=color,
                linestyle=system_style,
                label=self._label('development', dev_file, idx),
                **self._kwargs
            )

    def _set_axis(self):
        ''' Apply the DET axis limits (user given when all four are set,
        otherwise [0.01, 99, 0.01, 99]).'''
        if self._axlim is not None and None not in self._axlim:
            plot.det_axis(self._axlim)
        else:
            plot.det_axis([0.01, 99, 0.01, 99])

class Epc(PlotBase):
    ''' Handles the plotting of EPC '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if 'eval_scores_0' not in self._ctx.meta:
            raise click.UsageError("EPC requires dev and eval score files")
        self._title = self._title or 'EPC'
        self._x_label = self._x_label or r'$\alpha$'
        self._y_label = self._y_label or 'HTER (%)'
        self._eval = True  # always eval data with EPC
        self._split = False
        self._nb_figs = 1

    def compute(self, idx, dev_score, dev_file=None,
                eval_score=None, eval_file=None):
        ''' Plot EPC using :py:func:`bob.measure.plot.epc`

        The signature matches
        :py:func:`~bob.measure.script.figure.MeasureBase.compute`; EPC needs
        both dev and eval scores.'''
        dev_neg, dev_pos, _, eval_neg, eval_pos, _ = \
                self._process_scores(dev_score, eval_score)
        plot.epc(
            dev_neg, dev_pos, eval_neg, eval_pos, self._points,
            color=self._colors[idx],
            # cycle through the available patterns (avoids IndexError)
            linestyle=LINESTYLES[idx % len(LINESTYLES)],
            label=self._label('curve', '%s_%s' % (dev_file, eval_file), idx),
            **self._kwargs
        )

class Hist(PlotBase):
    ''' Functional base class for histograms'''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._nbins = meta.get('nbins')
        self._thres = meta.get('thres')
        # dev histograms are always shown without eval data
        self._show_dev = meta.get('show_dev', False) or not self._eval
        if self._thres is not None and len(self._thres) != self.n_sytem:
            # one threshold per system; a single value applies to all systems
            # (n_sytem counts systems, dev_names counts arguments)
            if len(self._thres) == 1:
                self._thres = self._thres * self.n_sytem
            else:
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_sytem
                )
        self._criter = meta.get('criter')
        # user given labels take precedence over the defaults
        self._y_label = self._y_label or (
            'Dev. probability density' if self._eval else 'density'
        )
        self._x_label = self._x_label or ('Scores' if not self._eval else '')
        self._title_base = self._title or 'Scores'
        # figures are saved in compute; PlotBase.end_process must not
        # decorate/save figures 1 and 2
        self._end_setup_plot = False

    def compute(self, idx, dev_score, dev_file=None,
                eval_score=None, eval_file=None):
        ''' Draw histograms of negative and positive scores in a new figure
        (dev on top and eval below when both are shown), save it to the pdf
        and close it.'''
        dev_neg, dev_pos, eval_neg, eval_pos, threshold = \
                self._get_neg_pos_thres(idx, dev_score, eval_score)
        has_eval = eval_neg is not None
        title = self._get_title(idx, dev_file, eval_file)

        fig = mpl.figure()
        if self._show_dev:
            if has_eval:
                mpl.subplot(2, 1, 1)
            self._setup_hist(dev_neg, dev_pos)
            mpl.title(title)
            mpl.ylabel(self._y_label)
            mpl.xlabel(self._x_label)
            if has_eval:
                # the eval subplot below carries the score tick labels
                mpl.gca().axes.get_xaxis().set_ticklabels([])
            # Setup lines, corresponding axis and legends
            self._lines(threshold, dev_neg, dev_pos)
            if self._eval:
                self._plot_legends()

        if has_eval:
            if self._show_dev:
                mpl.subplot(2, 1, 2)
            self._setup_hist(eval_neg, eval_pos)
            if not self._show_dev:
                mpl.title(title)
            mpl.ylabel('Eval. probability density')
            mpl.xlabel(self._x_label)
            # Setup lines, corresponding axis and legends
            self._lines(threshold, eval_neg, eval_pos)
            if not self._show_dev:
                self._plot_legends()

        self._pdf_page.savefig(fig)
        # one figure per system: release it once saved
        mpl.close(fig)

    def _get_title(self, idx, dev_file, eval_file):
        ''' Title of system ``idx``: the user title when given, otherwise the
        base title, followed by the file name(s) when ``show_fn`` is set.'''
        title = self._titles[idx] if self._titles is not None else None
        if title is None:
            if self._print_fn:
                file_names = str(dev_file)
                if self._eval:
                    file_names += " / %s" % str(eval_file)
                title = '%s \n (%s)' % (self._title_base, file_names)
            else:
                title = self._title_base
        return title

    def _plot_legends(self):
        ''' Draw a single legend gathering the handles of all axes of the
        current figure.'''
        lines = []
        labels = []
        for ax in mpl.gcf().get_axes():
            li, la = ax.get_legend_handles_labels()
            lines += li
            labels += la
        if self._show_dev and self._eval:
            mpl.legend(
                lines, labels, loc='upper center', ncol=6,
                bbox_to_anchor=(0.5, -0.01), fontsize=6
            )
        else:
            mpl.legend(lines, labels,
                       loc='best', fancybox=True, framealpha=0.5)

    def _get_neg_pos_thres(self, idx, dev_score, eval_score):
        ''' Return ``(dev_neg, dev_pos, eval_neg, eval_pos, threshold)``;
        the threshold is user given or computed on dev with the criterion.'''
        dev_neg, dev_pos, _, eval_neg, eval_pos, _ = self._process_scores(
            dev_score, eval_score
        )
        if self._thres is None:
            threshold = utils.get_thres(self._criter, dev_neg, dev_pos)
        else:
            threshold = self._thres[idx]
        return (dev_neg, dev_pos, eval_neg, eval_pos, threshold)

    def _density_hist(self, scores, **kwargs):
        ''' Plot a normalized histogram of ``scores`` with ``nbins`` bins.'''
        n, bins, patches = mpl.hist(
            scores, density=True, bins=self._nbins, **kwargs
        )
        return (n, bins, patches)

    def _lines(self, threshold, neg=None, pos=None, **kwargs):
        ''' Plot a vertical threshold line labelled with the criterion.

        ``neg``/``pos`` are not used here (presumably kept for subclasses
        that override this method).'''
        label = 'Threshold' if self._criter is None else self._criter.upper()
        kwargs.setdefault('color', 'C3')
        kwargs.setdefault('linestyle', '--')
        kwargs.setdefault('label', label)
        mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs)

    def _setup_hist(self, neg, pos):
        ''' Draw the positive and negative histograms.
        This function can be overwritten in derived classes'''
        self._density_hist(
            pos, label='Positives', alpha=0.5, color='blue', **self._kwargs
        )
        self._density_hist(
            neg, label='Negatives', alpha=0.5, color='red', **self._kwargs
        )