figure.py 28.5 KB
Newer Older
1 2 3 4 5
'''Runs error analysis on score sets, outputs metrics and plots'''

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
import sys
6
import os.path
7 8 9 10 11
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
12
from .. import (far_threshold, plot, utils, ppndf)
13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40

# Matplotlib dash patterns, written as (offset, (on, off, on, off, ...))
# in points. Plot classes pick one per system index; the list holds 13
# entries, so indexing should wrap with ``len(LINESTYLES)``.
LINESTYLES = [
    # plain patterns
    (0, ()),                       # solid
    (0, (4, 4)),                   # dashed
    (0, (1, 5)),                   # dotted
    (0, (3, 5, 1, 5)),             # dash-dotted
    (0, (3, 5, 1, 5, 1, 5)),       # dash-dot-dotted
    # dense variants
    (0, (5, 1)),                   # densely dashed
    (0, (1, 1)),                   # densely dotted
    (0, (3, 1, 1, 1)),             # densely dash-dotted
    (0, (3, 1, 1, 1, 1, 1)),       # densely dash-dot-dotted
    # loose variants
    (0, (5, 10)),                  # loosely dashed
    (0, (3, 10, 1, 10)),           # loosely dash-dotted
    (0, (3, 10, 1, 10, 1, 10)),    # loosely dash-dot-dotted
    (0, (1, 10)),                  # loosely dotted
]

class MeasureBase(object):
    """Base class for metrics and plots.

    This abstract class defines the framework to plot or compute metrics
    from a list of (positive, negative) score tuples. The constructor loads
    every input file once; :py:meth:`run` then calls :py:meth:`init_process`,
    :py:meth:`compute` once per system and finally :py:meth:`end_process`.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    dev_names, eval_names : :any:`list`
        One entry per required score argument (``min_arg``). Each entry is
        the list of file names (directory and everything after the first
        dot stripped), one per system, or ``None`` when no files were given
        for that argument.
    dev_scores, eval_scores : :any:`list`
        One entry per required score argument, holding what ``func_load``
        returned for the corresponding list of paths (``func_load`` is also
        called with ``None`` when no eval files were given).
    n_sytem : int
        Number of systems, i.e. the number of files in the first set of
        development files.
    """
    # NOTE: ``__metaclass__`` is only honoured by Python 2; Python 3 ignores
    # it, so the ``abstractmethod`` markers below are not enforced there.
    __metaclass__ = ABCMeta  # for python 2.7 compatibility

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx : :py:class:`click.Context`
            Click context; the options are read from ``ctx.meta``.
        scores : :any:`list`:
            List of input files (e.g. dev-{1, 2, 3}, {dev,eval}-scores1
            {dev,eval}-scores2)
        evaluation : :py:class:`bool`
            True if eval data are used
        func_load : Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If titles are given and their number differs from the number of
            systems.
        """
        meta = ctx.meta
        self._scores = scores
        # number of score-file arguments consumed per system; when it is 1,
        # ``compute`` receives single values instead of one-element lists
        self._min_arg = meta.get('min_arg', 1)
        self._ctx = ctx
        self.func_load = func_load
        self.dev_names, self.eval_names, self.dev_scores, self.eval_scores = \
            self._load_files()
        self.n_sytem = len(self.dev_names[0])  # at least one set of dev scores
        self._titles = meta.get('titles')
        if self._titles is not None and len(self._titles) != self.n_sytem:
            raise click.BadParameter("Number of titles must be equal to the "
                                     "number of systems")
        self._eval = evaluation

    def run(self):
        """ Generate outputs (e.g. metrics, files, pdf plots).

        This function calls abstract methods
        :func:`~bob.measure.script.figure.MeasureBase.init_process` (before
        loop), :py:func:`~bob.measure.script.figure.MeasureBase.compute`
        (in the loop iterating through the different
        systems) and :py:func:`~bob.measure.script.figure.MeasureBase.end_process`
        (after the loop).
        """
        # init matplotlib, log files, ...
        self.init_process()
        # iterate through the systems and feed `compute` with the dev (and
        # eval) scores of each one; a system may take several dev/eval score
        # sets (one per required argument)
        for idx in range(self.n_sytem):
            dev_score, dev_file, eval_score, eval_file = [], [], [], []
            for arg in range(self._min_arg):
                dev_score.append(self.dev_scores[arg][idx])
                dev_file.append(self.dev_names[arg][idx])
                eval_score.append(
                    self.eval_scores[arg][idx]
                    if self.eval_scores[arg] is not None else None
                )
                eval_file.append(
                    self.eval_names[arg][idx]
                    if self.eval_names[arg] is not None else None
                )
            if self._min_arg == 1:
                # most measures only take one argument, so do not wrap it in
                # a one-element list
                self.compute(idx, dev_score[0], dev_file[0], eval_score[0],
                             eval_file[0])
            else:
                self.compute(idx, dev_score, dev_file, eval_score, eval_file)
        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        before iterating through the different systems.
        Should be reimplemented in derived classes; does nothing here."""
        pass

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, dev_score, dev_file=None,
                eval_score=None, eval_file=None):
        """Compute metrics or plots from the given scores provided by
        :py:func:`~bob.measure.script.figure.MeasureBase.run`.
        Should be reimplemented in derived classes

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        dev_score:
            Development scores. Can be a tuple (neg, pos) of
            :py:class:`numpy.ndarray` (e.g.
            :py:func:`~bob.measure.script.figure.Roc.compute`) or
            a :any:`list` of tuples of :py:class:`numpy.ndarray` (e.g. cmc)
        dev_file : str
            name of the dev file without extension
        eval_score:
            eval scores. Can be a tuple (neg, pos) of
            :py:class:`numpy.ndarray` (e.g.
            :py:func:`~bob.measure.script.figure.Roc.compute`) or
            a :any:`list` of tuples of :py:class:`numpy.ndarray` (e.g. cmc)
        eval_file : str
            name of the eval file without extension
        """
        pass

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        after iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # common protected functions

    def _load_files(self):
        ''' Load the input files

        For each required score argument ``i`` the dev paths are read from
        ``ctx.meta['dev_scores_i']`` (falling back to the ``scores`` given to
        the constructor) and the eval paths from ``ctx.meta['eval_scores_i']``
        (``None`` when absent).

        Returns
        -------
        tuple
            ``(dev_files, eval_files, dev_scores, eval_scores)``: four lists
            with one entry per required argument. ``*_files`` hold the file
            names without directory and extension (or ``None``);
            ``*_scores`` hold the output of ``func_load``.
        '''

        def _extract_file_names(filenames):
            # basename up to the first dot, e.g. "/x/dev.scores.txt" -> "dev"
            if filenames is None:
                return None
            return [os.path.basename(path).split(".")[0] for path in filenames]

        meta = self._ctx.meta
        dev_scores = []
        eval_scores = []
        dev_files = []
        eval_files = []
        for arg in range(self._min_arg):
            dev_paths = meta.get('dev_scores_%d' % arg, self._scores)
            eval_paths = meta.get('eval_scores_%d' % arg)
            dev_files.append(_extract_file_names(dev_paths))
            eval_files.append(_extract_file_names(eval_paths))
            dev_scores.append(self.func_load(dev_paths))
            eval_scores.append(self.func_load(eval_paths))
        return (dev_files, eval_files, dev_scores, eval_scores)

    def _process_scores(self, dev_score, eval_score):
        '''Process score files and return neg/pos/fta for eval and dev

        Returns
        -------
        tuple
            ``(dev_neg, dev_pos, dev_fta, eval_neg, eval_pos, eval_fta)``;
            the eval entries stay ``None`` unless evaluation is enabled and
            eval scores were given.

        Raises
        ------
        click.UsageError
            If ``utils.get_fta`` yields no negatives for a score set.
        '''
        dev_neg = dev_pos = dev_fta = eval_neg = eval_pos = eval_fta = None
        if dev_score[0] is not None:
            (dev_neg, dev_pos), dev_fta = utils.get_fta(dev_score)
            if dev_neg is None:
                raise click.UsageError("While loading dev-score file")

        if self._eval and eval_score is not None and eval_score[0] is not None:
            (eval_neg, eval_pos), eval_fta = utils.get_fta(eval_score)
            if eval_neg is None:
                raise click.UsageError("While loading eval-score file")

        return (dev_neg, dev_pos, dev_fta, eval_neg, eval_pos, eval_fta)


class Metrics(MeasureBase):
    ''' Compute metrics from score files

    For every system, a threshold is either computed on the development
    scores (``criter`` option) or taken from the ``thres`` option, and a table
    of FtA, FMR, FNMR, FAR, FRR and HTER at that threshold is written for the
    development set and, when available, the evaluation set.

    Attributes
    ----------
    log_file:
        output stream: ``sys.stdout``, or the file named by the ``log``
        option
    '''
    # row labels of the printed table, in display order
    _ROW_NAMES = ('FtA', 'FMR', 'FNMR', 'FAR', 'FRR', 'HTER')

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._tablefmt = meta.get('tablefmt')
        self._criter = meta.get('criter')
        self._open_mode = meta.get('open_mode')
        self._thres = meta.get('thres')
        if self._thres is not None:
            # one threshold per *system*; a single value is shared by all.
            # (dev_names has one entry per score argument, not per system.)
            if len(self._thres) == 1:
                self._thres = self._thres * self.n_sytem
            elif len(self._thres) != self.n_sytem:
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_sytem
                )
        self._far = meta.get('far_value')
        self._log = meta.get('log')
        self.log_file = sys.stdout
        if self._log is not None:
            # open() rejects a None mode, so default to writing
            self.log_file = open(self._log, self._open_mode or 'w')

    def compute(self, idx, dev_score, dev_file=None,
                eval_score=None, eval_file=None):
        ''' Compute metrics thresholds and tables (FAR, FMR, FNMR, HTER) for
        given system inputs'''
        dev_neg, dev_pos, dev_fta, eval_neg, eval_pos, eval_fta = \
            self._process_scores(dev_score, eval_score)
        title = self._titles[idx] if self._titles is not None else None
        if self._thres is None:
            threshold = utils.get_thres(
                self._criter, dev_neg, dev_pos, self._far
            )
            far_str = ''
            if self._criter == 'far' and self._far is not None:
                far_str = str(self._far)
            click.echo("[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                       % (self._criter.upper(), far_str, title or dev_file, threshold),
                       file=self.log_file)
        else:
            threshold = self._thres[idx]
            # prefer the title, as in the computed-threshold message above
            click.echo("[Min. criterion: user provider] Threshold on "
                       "Development set `%s`: %e"
                       % (title or dev_file, threshold), file=self.log_file)

        # an empty top-left header cell when no title was given
        headers = [title or '', 'Development %s' % dev_file]
        raws = [
            [name, value] for name, value in zip(
                self._ROW_NAMES,
                self._rate_strings(dev_neg, dev_pos, dev_fta, threshold)
            )
        ]

        if self._eval and eval_neg is not None:
            # statistics for the eval set use the a-priori (dev) threshold
            eval_rates = self._rate_strings(
                eval_neg, eval_pos, eval_fta, threshold
            )
            headers.append('Eval. %s' % eval_file)
            for row, value in zip(raws, eval_rates):
                row.append(value)

        click.echo(tabulate(raws, headers, self._tablefmt), file=self.log_file)

    @staticmethod
    def _rate_strings(neg, pos, fta, threshold):
        '''Format the error rates of one score set at ``threshold``.

        Returns the FtA, FMR, FNMR, FAR, FRR and HTER strings, in the order of
        ``_ROW_NAMES``. FMR/FNMR also show the raw counts.
        '''
        from .. import farfrr
        fmr, fnmr = farfrr(neg, pos, threshold)
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)
        hter = (far + frr) / 2.0

        ni = neg.shape[0]  # number of impostors
        fm = int(round(fmr * ni))  # number of false accepts
        nc = pos.shape[0]  # number of clients
        fnm = int(round(fnmr * nc))  # number of false rejects

        return [
            "%.1f%%" % (100 * fta),
            "%.1f%% (%d/%d)" % (100 * fmr, fm, ni),
            "%.1f%% (%d/%d)" % (100 * fnmr, fnm, nc),
            "%.1f%%" % (100 * far),
            "%.1f%%" % (100 * frr),
            "%.1f%%" % (100 * hter),
        ]

    def end_process(self):
        ''' Close log file if needed'''
        if self._log is not None:
            self.log_file.close()

class PlotBase(MeasureBase):
    ''' Base class for plots. Regroup several options and code
    shared by the different plots
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._output = meta.get('output')
        self._points = meta.get('points', 100)
        self._split = meta.get('split')
        self._axlim = meta.get('axlim')
        self._clayout = meta.get('clayout')
        self._far_at = meta.get('lines_at')
        # x positions of the vertical FAR lines; subclasses may transform
        # them (Det maps them through ppndf)
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            self._eval_points = {line: [] for line in self._far_at}
            self._lines_val = []
        self._print_fn = meta.get('show_fn', True)
        self._x_rotation = meta.get('x_rotation')
        if 'style' in meta:
            mpl.style.use(meta['style'])
        self._nb_figs = 2 if self._eval and self._split else 1
        # NOTE(review): dev_scores has one entry per score argument
        # (min_arg), not per system, so this is only True when min_arg > 1
        # -- confirm whether ``self.n_sytem > 1`` was intended.
        self._multi_plots = len(self.dev_scores) > 1
        # compute() indexes colors by system index, so size them by systems
        self._colors = utils.get_colors(self.n_sytem)
        self._states = ['Development', 'Evaluation']
        self._title = meta.get('title')
        self._x_label = meta.get('x_label')
        self._y_label = meta.get('y_label')
        self._grid_color = 'silver'
        self._pdf_page = None
        self._end_setup_plot = True
        # extra keyword arguments forwarded to the plotting functions
        self._kwargs = {}

    def init_process(self):
        ''' Open pdf and set axis font size if provided '''
        if not hasattr(matplotlib, 'backends'):
            matplotlib.use('pdf')

        meta = self._ctx.meta
        if 'PdfPages' in meta:
            self._pdf_page = meta['PdfPages']
        else:
            self._pdf_page = PdfPages(self._output)

        figsize = meta.get('figsize')
        for i in range(self._nb_figs):
            fig = mpl.figure(i + 1, figsize=figsize)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        ''' Set title, legend, axis labels, grid colors, save figures, draw
        lines and close pdf if needed '''
        # draw vertical lines at the requested FAR values
        if self._far_at is not None:
            for line, line_trans in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans], [-100.0, 100.], "--",
                    color='black'
                )
                if self._eval and self._split:
                    mpl.figure(2)
                    # join the eval operating points of all systems in
                    # increasing x order
                    points = sorted(self._eval_points[line],
                                    key=lambda point: point[0])
                    mpl.plot([x for x, _ in points],
                             [y for _, y in points], '--',
                             color='black')

        # only for plots
        if self._end_setup_plot:
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = self._title
                if not self._eval:
                    title += (" (%s)" % self._states[0])
                elif self._split:
                    title += (" (%s)" % self._states[i])
                mpl.title(title)
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                mpl.legend(loc='best')
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        # A PdfPages shared through ctx.meta (e.g. when running evaluate) is
        # only closed when ``closef`` allows it; one created by init_process
        # belongs to this object and is always closed so the PDF is written.
        meta = self._ctx.meta
        if self._pdf_page is not None and \
           ('PdfPages' not in meta or meta.get('closef', True)):
            self._pdf_page.close()

    # common protected functions

    def _label(self, base, name, idx):
        '''Legend label of system ``idx``: its title when one was given,
        otherwise ``base`` followed by ``name`` (plus the 1-based system
        number when ``_multi_plots`` is set).'''
        if self._titles is not None and len(self._titles) > idx:
            return self._titles[idx]
        if self._multi_plots:
            return base + (" %d (%s)" % (idx + 1, name))
        return base + (" (%s)" % name)

    def _set_axis(self):
        '''Apply the axis limits when all four were provided.'''
        if self._axlim is not None and None not in self._axlim:
            mpl.axis(self._axlim)

class Roc(PlotBase):
    ''' Handles the plotting of ROC'''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._title = self._title or 'ROC'
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or "1 - False Negative Rate"
        # custom defaults
        if self._axlim is None:
            self._axlim = [1e-4, 1.0, 1e-4, 1.0]

    def compute(self, idx, dev_score, dev_file=None,
                eval_score=None, eval_file=None):
        ''' Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc_for_far`'''
        dev_neg, dev_pos, _, eval_neg, eval_pos, _ = \
            self._process_scores(dev_score, eval_score)
        # wrap around the available styles instead of overflowing them
        cycled_style = LINESTYLES[idx % len(LINESTYLES)]
        mpl.figure(1)
        if self._eval and eval_neg is not None:
            linestyle = '-' if not self._split else cycled_style
            plot.roc_for_far(
                dev_neg, dev_pos,
                color=self._colors[idx], linestyle=linestyle,
                label=self._label('development', dev_file, idx),
                **self._kwargs
            )
            linestyle = '--'
            if self._split:
                mpl.figure(2)
                linestyle = cycled_style

            plot.roc_for_far(
                eval_neg, eval_pos,
                color=self._colors[idx], linestyle=linestyle,
                label=self._label('eval', eval_file, idx),
                **self._kwargs
            )
            if self._far_at is not None:
                from .. import farfrr
                for line in self._far_at:
                    # eval operating point at the dev threshold for this FAR
                    thres_line = far_threshold(dev_neg, dev_pos, line)
                    eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, thres_line)
                    # the ROC y axis is 1 - FNMR
                    eval_fnmr = 1 - eval_fnmr
                    mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                    self._eval_points[line].append((eval_fmr, eval_fnmr))
        else:
            plot.roc_for_far(
                dev_neg, dev_pos,
                color=self._colors[idx], linestyle=cycled_style,
                label=self._label('development', dev_file, idx),
                **self._kwargs
            )

class Det(PlotBase):
    ''' Handles the plotting of DET '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._title = self._title or 'DET'
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or 'False Negative Rate'
        if self._far_at is not None:
            # vertical lines live in the ppndf-transformed DET coordinates
            self._trans_far_val = [ppndf(float(k)) for k in self._far_at]
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

    def compute(self, idx, dev_score, dev_file=None,
                eval_score=None, eval_file=None):
        ''' Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`'''
        dev_neg, dev_pos, _, eval_neg, eval_pos, _ = \
            self._process_scores(dev_score, eval_score)
        # wrap around the available styles instead of overflowing them
        cycled_style = LINESTYLES[idx % len(LINESTYLES)]
        mpl.figure(1)
        if self._eval and eval_neg is not None:
            linestyle = '-' if not self._split else cycled_style
            plot.det(
                dev_neg, dev_pos, self._points, color=self._colors[idx],
                linestyle=linestyle,
                label=self._label('development', dev_file, idx),
                **self._kwargs
            )
            if self._split:
                mpl.figure(2)
            linestyle = '--' if not self._split else cycled_style
            plot.det(
                eval_neg, eval_pos, self._points, color=self._colors[idx],
                linestyle=linestyle,
                label=self._label('eval', eval_file, idx),
                **self._kwargs
            )
            if self._far_at is not None:
                from .. import farfrr
                for line in self._far_at:
                    # eval operating point at the dev threshold for this FAR
                    thres_line = far_threshold(dev_neg, dev_pos, line)
                    eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, thres_line)
                    eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
                    mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                    self._eval_points[line].append((eval_fmr, eval_fnmr))
        else:
            plot.det(
                dev_neg, dev_pos, self._points, color=self._colors[idx],
                linestyle=cycled_style,
                label=self._label('development', dev_file, idx),
                **self._kwargs
            )

    def _set_axis(self):
        '''Apply the DET axis limits, with a default when not fully given.'''
        if self._axlim is not None and None not in self._axlim:
            plot.det_axis(self._axlim)
        else:
            plot.det_axis([0.01, 99, 0.01, 99])

class Epc(PlotBase):
    ''' Handles the plotting of EPC '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if 'eval_scores_0' not in self._ctx.meta:
            raise click.UsageError("EPC requires dev and eval score files")
        self._title = self._title or 'EPC'
        self._x_label = self._x_label or r'$\alpha$'
        self._y_label = self._y_label or 'HTER (%)'
        self._eval = True  # always eval data with EPC
        self._split = False
        self._nb_figs = 1
        self._far_at = None

    def compute(self, idx, dev_score, dev_file, eval_score, eval_file=None):
        ''' Plot EPC using :py:func:`bob.measure.plot.epc` '''
        dev_neg, dev_pos, _, eval_neg, eval_pos, _ = \
            self._process_scores(dev_score, eval_score)
        plot.epc(
            dev_neg, dev_pos, eval_neg, eval_pos, self._points,
            color=self._colors[idx],
            # wrap around the available styles instead of overflowing them
            linestyle=LINESTYLES[idx % len(LINESTYLES)],
            label=self._label('curve', dev_file + "_" + eval_file, idx),
            **self._kwargs
        )

class Hist(PlotBase):
    ''' Functional base class for histograms'''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._nbins = meta.get('nbins')
        self._thres = meta.get('thres')
        # dev histograms are drawn when requested, and always without eval
        self._show_dev = meta.get('show_dev', False) or not self._eval
        if self._thres is not None:
            # one threshold per *system*; a single value is shared by all.
            # (dev_names has one entry per score argument, not per system.)
            if len(self._thres) == 1:
                self._thres = self._thres * self.n_sytem
            elif len(self._thres) != self.n_sytem:
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_sytem
                )
        self._criter = meta.get('criter')
        # user-provided labels win; defaults depend on the presence of eval
        self._y_label = self._y_label or (
            'Dev. probability density' if self._eval else 'density'
        )
        self._x_label = self._x_label or ('Scores' if not self._eval else '')
        self._title_base = self._title or 'Scores'
        self._end_setup_plot = False

    def compute(self, idx, dev_score, dev_file=None,
                eval_score=None, eval_file=None):
        ''' Draw histograms of negative and positive scores.

        One figure per system is saved to the PDF: dev on top and eval
        below when both are shown, otherwise a single panel.
        '''
        dev_neg, dev_pos, eval_neg, eval_pos, threshold = \
            self._get_neg_pos_thres(idx, dev_score, eval_score)

        fig = mpl.figure()
        if eval_neg is not None and self._show_dev:
            mpl.subplot(2, 1, 1)
        if self._show_dev:
            self._setup_hist(dev_neg, dev_pos)
            mpl.title(self._get_title(idx, dev_file, eval_file))
            mpl.ylabel(self._y_label)
            mpl.xlabel(self._x_label)
            if eval_neg is not None and self._show_dev:
                # the eval panel below carries the x tick labels
                ax = mpl.gca()
                ax.axes.get_xaxis().set_ticklabels([])
            # Setup lines, corresponding axis and legends
            self._lines(threshold, dev_neg, dev_pos)
            # NOTE(review): dev-only figures (no eval) get no legend here --
            # confirm this is intended.
            if self._eval:
                self._plot_legends()

        if eval_neg is not None:
            if self._show_dev:
                mpl.subplot(2, 1, 2)
            self._setup_hist(eval_neg, eval_pos)
            if not self._show_dev:
                mpl.title(self._get_title(idx, dev_file, eval_file))
            mpl.ylabel('Eval. probability density')
            mpl.xlabel(self._x_label)
            # Setup lines, corresponding axis and legends
            self._lines(threshold, eval_neg, eval_pos)
            if not self._show_dev:
                self._plot_legends()

        self._pdf_page.savefig(fig)
        # the figure is saved; close it so one figure per system does not
        # pile up in pyplot's figure manager
        mpl.close(fig)

    def _get_title(self, idx, dev_file, eval_file):
        '''Title of system ``idx``: its user title, else the base title,
        optionally followed by the file names when ``show_fn`` is set.'''
        title = self._titles[idx] if self._titles is not None else None
        if title is None:
            title = self._title_base if not self._print_fn else \
                ('%s \n (%s)' % (
                    self._title_base,
                    str(dev_file) + (" / %s" % str(eval_file) if self._eval else "")
                ))
        return title

    def _plot_legends(self):
        '''Draw one legend gathering the handles of every axis of the
        current figure.'''
        lines = []
        labels = []
        for ax in mpl.gcf().get_axes():
            li, la = ax.get_legend_handles_labels()
            lines += li
            labels += la
        if self._show_dev and self._eval:
            # placed just below the dev panel, between the two subplots
            mpl.legend(
                lines, labels, loc='upper center', ncol=6,
                bbox_to_anchor=(0.5, -0.01), fontsize=6
            )
        else:
            mpl.legend(lines, labels,
                       loc='best', fancybox=True, framealpha=0.5)

    def _get_neg_pos_thres(self, idx, dev_score, eval_score):
        '''Return dev/eval negatives and positives and the threshold of
        system ``idx`` (user-given, or computed on dev with ``criter``).'''
        dev_neg, dev_pos, _, eval_neg, eval_pos, _ = self._process_scores(
            dev_score, eval_score
        )
        threshold = utils.get_thres(
            self._criter, dev_neg,
            dev_pos
        ) if self._thres is None else self._thres[idx]
        return (dev_neg, dev_pos, eval_neg, eval_pos, threshold)

    def _density_hist(self, scores, **kwargs):
        '''Normalized histogram of ``scores`` with ``nbins`` bins.'''
        n, bins, patches = mpl.hist(
            scores, density=True, bins=self._nbins, **kwargs
        )
        return (n, bins, patches)

    def _lines(self, threshold, neg=None, pos=None, **kwargs):
        '''Draw the vertical threshold line. ``neg``/``pos`` are unused here
        and kept for overriding classes.'''
        label = 'Threshold' if self._criter is None else self._criter.upper()
        kwargs.setdefault('color', 'C3')
        kwargs.setdefault('linestyle', '--')
        kwargs.setdefault('label', label)
        # plot a vertical threshold line
        mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs)

    def _setup_hist(self, neg, pos):
        ''' This function can be overwritten in derived classes'''
        self._density_hist(
            pos, label='Positives', alpha=0.5, color='C0', **self._kwargs
        )
        self._density_hist(
            neg, label='Negatives', alpha=0.5, color='C3', **self._kwargs
        )