figure.py 30 KB
Newer Older
1 2 3 4
'''Runs error analysis on score sets, outputs metrics and plots'''

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
5
import math
6
import sys
7
import os.path
8 9 10 11 12
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
13
from .. import (far_threshold, plot, utils, ppndf)
14

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
15

16 17 18 19 20 21 22 23 24 25 26 27
def check_list_value(values, desired_number, name, name2='systems'):
    """Broadcast a single option value, or validate a per-item list.

    Parameters
    ----------
    values : :any:`list` or None
        Values given by the user. ``None`` is returned unchanged.
    desired_number : int
        Number of values expected (e.g. number of systems).
    name : str
        Name of the option, used in the error message.
    name2 : str
        Name of the counted items, used in the error message.

    Returns
    -------
    :any:`list` or None
        ``values`` itself when it already has ``desired_number`` items (or is
        ``None``); a one-item list repeated ``desired_number`` times otherwise.

    Raises
    ------
    click.BadParameter
        If ``values`` has neither 1 nor ``desired_number`` items.
    """
    if values is None or len(values) == desired_number:
        return values
    if len(values) != 1:
        raise click.BadParameter(
            '#{} ({}) must be either 1 value or the same as '
            '#{} ({} values)'.format(name, values, name2, desired_number))
    # a single value is broadcast to every item
    return values * desired_number


28 29 30 31 32 33 34 35 36 37
class MeasureBase(object):
    """Base class for metrics and plots.
    This abstract class defines the framework to plot or compute metrics from
    a list of (positive, negative) scores tuples.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    n_systems: int
        Number of systems, i.e. ``len(scores) // min_arg``
    """
    __metaclass__ = ABCMeta  # for python 2.7 compatibility

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx : :py:class:`click.Context`
            Click context; options are read from its ``meta`` dictionary
            (``legends``, ``min_arg``).
        scores : :any:`list`:
            List of input files (e.g. dev-{1, 2, 3}, {dev,eval}-scores1
            {dev,eval}-scores2)
        evaluation : :py:class:`bool`
            True if eval data are used
        func_load : Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If the number of score files is not a non-zero multiple of
            ``ctx.meta['min_arg']`` (default 1), or if fewer legends than
            systems are given.
        """
        self._scores = scores
        self._ctx = ctx
        self.func_load = func_load
        self._legends = ctx.meta.get('legends')
        self._eval = evaluation
        # number of files that make up one system (e.g. 2 for dev + eval)
        self._min_arg = ctx.meta.get('min_arg', 1)
        if len(scores) < 1 or len(scores) % self._min_arg != 0:
            raise click.BadParameter(
                'Number of argument must be a non-zero multiple of %d' % self._min_arg
            )
        self.n_systems = len(scores) // self._min_arg
        if self._legends is not None and len(self._legends) < self.n_systems:
            raise click.BadParameter("Number of legends must be >= to the "
                                     "number of systems")

    def run(self):
        """ Generate outputs (e.g. metrics, files, pdf plots).
        This function calls abstract methods
        :func:`~bob.measure.script.figure.MeasureBase.init_process` (before
        loop), :py:func:`~bob.measure.script.figure.MeasureBase.compute`
        (in the loop iterating through the different
        systems) and :py:func:`~bob.measure.script.figure.MeasureBase.end_process`
        (after the loop).
        """
        # init matplotlib, log files, ...
        self.init_process()
        # iterates through the different systems and feed `compute`
        # with the dev (and eval) scores of each system
        # Note that more than one dev or eval scores score can be passed to
        # each system
        for idx in range(self.n_systems):
            # Scores are given as followed:
            # SysA-dev SysA-eval ... SysA-XX  SysB-dev SysB-eval ... SysB-XX
            # ------------------------------  ------------------------------
            #   First set of `self._min_arg`     Second set of input files
            #     input files starting at               for SysB
            #    index idx * self._min_arg
            start = idx * self._min_arg
            # load scores for each system: get the corresponding arrays and
            # base-name of files
            input_scores, input_names = self._load_files(
                self._scores[start:start + self._min_arg]
            )
            self.compute(idx, input_scores, input_names)
        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        before iterating through the different systems.
        Should be reimplemented in derived classes; does nothing here."""
        pass

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots from the given scores provided by
        :py:func:`~bob.measure.script.figure.MeasureBase.run`.
        Should be reimplemented in derived classes

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system

        Notes
        -----
        Structure of the input (vuln example), if evaluation is provided::

            [ (dev_licit_neg, dev_licit_pos), (eval_licit_neg, eval_licit_pos),
              (dev_spoof_neg, dev_licit_pos), (eval_spoof_neg, eval_licit_pos)]

        and if only dev::

            [ (dev_licit_neg, dev_licit_pos), (dev_spoof_neg, dev_licit_pos)]
        """
        pass

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        after iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # common protected functions

    def _load_files(self, filepaths):
        ''' Load the input files and return the base names of the files

        Returns
        -------
            scores: :any:`list`:
                A list that contains the output of
                ``func_load`` for the given files
            basenames: :any:`list`:
                A list of basenames for the given files, truncated at the
                first ``.`` (e.g. ``/a/x.dev.txt`` -> ``x``)
        '''
        scores = []
        basenames = []
        for filename in filepaths:
            basenames.append(os.path.basename(filename).split(".")[0])
            scores.append(self.func_load(filename))
        return scores, basenames
159

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
160

161 162 163 164 165 166 167 168
class Metrics(MeasureBase):
    ''' Compute metrics from score files

    Attributes
    ----------
    log_file: file-like
        output stream: ``sys.stdout``, or the file ``ctx.meta['log']``
        opened with ``ctx.meta['open_mode']``
    '''

    def __init__(self, ctx, scores, evaluation, func_load,
                 names=('FtA', 'FMR', 'FNMR', 'FAR', 'FRR', 'HTER')):
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        # row labels of the output table, in the order produced by
        # _format_rates
        self.names = names
        self._tablefmt = ctx.meta.get('tablefmt')
        self._criterion = ctx.meta.get('criterion')
        self._open_mode = ctx.meta.get('open_mode')
        self._thres = ctx.meta.get('thres')
        if self._thres is not None:
            if len(self._thres) == 1:
                # one threshold given: use it for every system
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                # n_systems is an int: format it directly (len() would raise
                # TypeError and hide this error message)
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._far = ctx.meta.get('far_value')
        self._log = ctx.meta.get('log')
        self.log_file = sys.stdout
        if self._log is not None:
            self.log_file = open(self._log, self._open_mode)

    def get_thres(self, criterion, dev_neg, dev_pos, far):
        ''' Return the threshold computed by ``utils.get_thres`` for the
        given criterion on the development scores'''
        return utils.get_thres(criterion, dev_neg, dev_pos, far)

    @staticmethod
    def _format_rates(neg, pos, fta, threshold):
        ''' Compute the formatted FtA, FMR, FNMR, FAR, FRR and HTER strings of
        one score set at the given threshold.

        FAR and FRR take the failure-to-acquire rate ``fta`` into account;
        FMR and FNMR strings also show the raw error counts.
        '''
        from .. import farfrr
        fmr, fnmr = farfrr(neg, pos, threshold)
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)
        hter = (far + frr) / 2.0

        ni = neg.shape[0]  # number of impostors
        fm = int(round(fmr * ni))  # number of false accepts
        nc = pos.shape[0]  # number of clients
        fnm = int(round(fnmr * nc))  # number of false rejects

        return [
            "%.1f%%" % (100 * fta),
            "%.1f%% (%d/%d)" % (100 * fmr, fm, ni),
            "%.1f%% (%d/%d)" % (100 * fnmr, fnm, nc),
            "%.1f%%" % (100 * far),
            "%.1f%%" % (100 * frr),
            "%.1f%%" % (100 * hter),
        ]

    def compute(self, idx, input_scores, input_names):
        ''' Compute metrics thresholds and tables (FAR, FMR, FNMR, HTER) for
        given system inputs

        The threshold is computed on the development set with the selected
        criterion, unless the user provided one. It is then applied, as is,
        to the evaluation set (if any). The threshold line and the table are
        written to ``self.log_file``.
        '''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]
            eval_file = input_names[1]

        if self._thres is None:
            threshold = self.get_thres(self._criterion, dev_neg, dev_pos,
                                       self._far)
        else:
            threshold = self._thres[idx]
        title = self._legends[idx] if self._legends is not None else None
        if self._thres is None:
            far_str = ''
            if self._criterion == 'far' and self._far is not None:
                far_str = str(self._far)
            click.echo("[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                       % (self._criterion.upper(),
                          far_str, title or dev_file,
                          threshold),
                       file=self.log_file)
        else:
            # prefer the legend over the file name, like the branch above
            click.echo("[Min. criterion: user provided] Threshold on "
                       "Development set `%s`: %e"
                       % (title or dev_file, threshold), file=self.log_file)

        # `title or ''` so that a missing legend gives an empty header cell
        # instead of the text "None"
        headers = [title or '', 'Development %s' % dev_file]
        raws = [[self.names[i], value] for i, value in enumerate(
            self._format_rates(dev_neg, dev_pos, dev_fta, threshold))]

        if self._eval:
            # computes statistics for the eval set based on the threshold a priori
            eval_rates = self._format_rates(eval_neg, eval_pos, eval_fta,
                                            threshold)
            headers.append('Eval. % s' % eval_file)
            for row, value in zip(raws, eval_rates):
                row.append(value)

        click.echo(tabulate(raws, headers, self._tablefmt), file=self.log_file)

    def end_process(self):
        ''' Close log file if needed'''
        if self._log is not None:
            self.log_file.close()

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
285

286 287 288 289
class PlotBase(MeasureBase):
    ''' Base class for plots. Regroup several options and code
    shared by the different plots
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        self._output = ctx.meta.get('output')
        self._points = ctx.meta.get('points', 100)
        self._split = ctx.meta.get('split')
        self._axlim = ctx.meta.get('axlim')
        self._disp_legend = ctx.meta.get('disp_legend', True)
        self._legend_loc = ctx.meta.get('legend_loc')
        # power of ten of the smallest value shown on the FAR axis, if known
        self._min_dig = None
        if 'min_far_value' in ctx.meta:
            self._min_dig = int(math.log10(ctx.meta['min_far_value']))
        elif self._axlim is not None and self._axlim[0] is not None:
            self._min_dig = int(math.log10(self._axlim[0])
                                if self._axlim[0] != 0 else 0)
        self._clayout = ctx.meta.get('clayout')
        self._far_at = ctx.meta.get('lines_at')
        # x positions of the vertical lines; subclasses may transform them
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            self._eval_points = {line: [] for line in self._far_at}
            self._lines_val = []
        self._print_fn = ctx.meta.get('show_fn', True)
        self._x_rotation = ctx.meta.get('x_rotation')
        if 'style' in ctx.meta:
            mpl.style.use(ctx.meta['style'])
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._line_linestyles = ctx.meta.get('line_linestyles', False)
        self._linestyles = utils.get_linestyles(
            self.n_systems, self._line_linestyles)
        # `or []` also covers a `titles` key explicitly set to None
        self._titles = (ctx.meta.get('titles') or []) * 2
        # for compatibility
        self._title = ctx.meta.get('title')
        if not self._titles and self._title is not None:
            self._titles = [self._title] * 2

        self._x_label = ctx.meta.get('x_label')
        self._y_label = ctx.meta.get('y_label')
        self._grid_color = 'silver'
        self._pdf_page = None
        # True when init_process created the PdfPages and must close it
        self._own_pdf = False
        self._end_setup_plot = True

    def init_process(self):
        ''' Open pdf and set axis font size if provided

        The pdf is taken from ``ctx.meta['PdfPages']`` when present
        (shared with other commands), otherwise a new ``PdfPages`` is
        created on ``self._output`` and owned by this object.
        '''
        if not hasattr(matplotlib, 'backends'):
            matplotlib.use('pdf')

        if 'PdfPages' in self._ctx.meta:
            self._pdf_page = self._ctx.meta['PdfPages']
            self._own_pdf = False
        else:
            self._pdf_page = PdfPages(self._output)
            self._own_pdf = True

        fs = self._ctx.meta.get('figsize')
        for i in range(self._nb_figs):
            fig = mpl.figure(i + 1, figsize=fs)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        ''' Set title, legend, axis labels, grid colors, save figures, draw
        lines and close pdf if needed '''
        # draw vertical lines
        if self._far_at is not None:
            for (line, line_trans) in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans], [-100.0, 100.], "--",
                    color='black'
                )
                if self._eval and self._split:
                    mpl.figure(2)
                    # join the eval operating points of all systems, sorted
                    # by x so the dashed line is drawn left to right
                    points = sorted(self._eval_points[line],
                                    key=lambda point: point[0])
                    x_values = [i for i, _ in points]
                    y_values = [j for _, j in points]
                    mpl.plot(x_values,
                             y_values, '--',
                             color='black')
        # only for plots
        if self._end_setup_plot:
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = '' if not self._titles else self._titles[i]
                # a title made only of spaces means "no title"
                mpl.title(title if title.replace(' ', '') else '')
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                if self._disp_legend:
                    mpl.legend(loc=self._legend_loc)
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        if getattr(self, '_own_pdf', False):
            # created in init_process: nobody else will finalize this pdf
            self._pdf_page.close()
        elif 'PdfPages' in self._ctx.meta and \
                self._ctx.meta.get('closef', True):
            # do not want to close PDF when running evaluate (closef=False)
            self._pdf_page.close()

    # common protected functions

    def _label(self, base, name, idx):
        ''' Legend label of system ``idx``: the user legend if given, else
        ``base`` followed by the (numbered) file name '''
        if self._legends is not None and len(self._legends) > idx:
            return self._legends[idx]
        if self.n_systems > 1:
            return base + (" %d (%s)" % (idx + 1, name))
        return base + (" (%s)" % name)

    def _set_axis(self):
        ''' Apply the axis limits, if any, to the current axes '''
        if self._axlim is not None:
            mpl.axis(self._axlim)
401

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
402

403
class Roc(PlotBase):
    ''' Handles the plotting of ROC'''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ['ROC dev', 'ROC eval']
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or "1 - False Negative Rate"
        self._semilogx = ctx.meta.get('semilogx', True)
        if not self._legend_loc:
            # the curve hugs the top-left corner: keep the legend away from it
            self._legend_loc = ('lower right' if self._semilogx
                                else 'upper right')
        # custom defaults
        if self._axlim is None:
            self._axlim = [None, None, -0.05, 1.05]

    def compute(self, idx, input_scores, input_names):
        ''' Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc`

        The dev curve always goes on figure 1. With eval data, the eval curve
        goes on figure 2 when splitting, or dashed on figure 1 otherwise, and
        the eval operating points at the dev thresholds of ``lines_at`` are
        marked.
        '''
        negatives, positives, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = negatives[0], positives[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = negatives[1], positives[1]
            eval_file = input_names[1]

        mpl.figure(1)
        plot.roc_for_far(
            dev_neg, dev_pos,
            far_values=plot.log_values(self._min_dig or -4),
            CAR=self._semilogx,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label('dev', dev_file, idx)
        )
        if not self._eval:
            return

        if self._split:
            mpl.figure(2)
        plot.roc_for_far(
            eval_neg, eval_pos,
            linestyle=self._linestyles[idx] if self._split else '--',
            far_values=plot.log_values(self._min_dig or -4),
            CAR=self._semilogx,
            color=self._colors[idx],
            label=self._label('eval', eval_file, idx)
        )
        if self._far_at is None:
            return

        from .. import farfrr
        for far_line in self._far_at:
            dev_thres = far_threshold(dev_neg, dev_pos, far_line)
            point_fmr, point_fnmr = farfrr(eval_neg, eval_pos, dev_thres)
            # ROC y axis is 1 - FNMR
            point_car = 1 - point_fnmr
            mpl.scatter(point_fmr, point_car, c=self._colors[idx], s=30)
            self._eval_points[far_line].append((point_fmr, point_car))

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
466

467 468
class Det(PlotBase):
    ''' Handles the plotting of DET '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ['DET dev', 'DET eval']
        self._x_label = self._x_label or 'False Positive Rate (%)'
        self._y_label = self._y_label or 'False Negative Rate (%)'
        self._legend_loc = self._legend_loc or 'upper right'
        if self._far_at is not None:
            # vertical lines are drawn in the DET (ppndf) coordinates
            self._trans_far_val = [ppndf(float(k)) for k in self._far_at]
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

        if self._axlim is None:
            self._axlim = [0.01, 99, 0.01, 99]
        else:
            # work on a copy: the edit below must not modify the list stored
            # in ctx.meta (shared with other plots) nor fail on a tuple
            self._axlim = list(self._axlim)

        if self._min_dig is not None:
            # DET axes are in percent
            self._axlim[0] = math.pow(10, self._min_dig) * 100

    def compute(self, idx, input_scores, input_names):
        ''' Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`

        The dev curve always goes on figure 1. With eval data, the eval curve
        goes on figure 2 when splitting, or dashed on figure 1 otherwise, and
        the eval operating points at the dev thresholds of ``lines_at`` are
        marked.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        eval_neg = eval_pos = eval_file = None
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        mpl.figure(1)
        # the dev curve is drawn the same way with or without eval data
        plot.det(
            dev_neg, dev_pos, self._points, color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label('development', dev_file, idx)
        )
        if not self._eval or eval_neg is None:
            return

        if self._split:
            mpl.figure(2)
        linestyle = '--' if not self._split else self._linestyles[idx]
        plot.det(
            eval_neg, eval_pos, self._points, color=self._colors[idx],
            linestyle=linestyle,
            label=self._label('eval', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(
                    eval_neg, eval_pos, thres_line)
                # convert to DET (ppndf) coordinates
                eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

    def _set_axis(self):
        ''' Apply the DET axis limits and ticks '''
        plot.det_axis(self._axlim)
531

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
532

533 534
class Epc(PlotBase):
    ''' Handles the plotting of EPC '''

    def __init__(self, ctx, scores, evaluation, func_load, hter='HTER'):
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev and eval score files")
        self._titles = self._titles or ['EPC'] * 2
        self._x_label = self._x_label or r'$\alpha$'
        self._y_label = self._y_label or hter + ' (%)'
        self._legend_loc = self._legend_loc or 'upper center'
        # EPC always uses eval data, draws a single (unsplit) figure and has
        # no vertical FAR lines
        self._eval, self._split, self._nb_figs, self._far_at = \
            True, False, 1, None

    def compute(self, idx, input_scores, input_names):
        ''' Plot EPC using :py:func:`bob.measure.plot.epc`

        One curve per system, labelled with the dev and eval base names
        joined by ``_``.
        '''
        negatives, positives, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = negatives[0], positives[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = negatives[1], positives[1]
            eval_file = input_names[1]

        curve_label = self._label('curve', dev_file + "_" + eval_file, idx)
        plot.epc(
            dev_neg, dev_pos, eval_neg, eval_pos, self._points,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=curve_label
        )

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
566

567
class Hist(PlotBase):
    ''' Functional base class for histograms'''

    def __init__(self, ctx, scores, evaluation, func_load, nhist_per_system=2):
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        self._nbins = ctx.meta.get('n_bins', ['doane'])
        self._nhist_per_system = nhist_per_system
        self._nbins = check_list_value(
            self._nbins, nhist_per_system, 'n_bins',
            'histograms')
        self._thres = ctx.meta.get('thres')
        self._thres = check_list_value(
            self._thres, self.n_systems, 'thresholds')
        self._criterion = ctx.meta.get('criterion')
        # no vertical (threshold) is displayed
        self._no_line = ctx.meta.get('no_line', False)
        # subplot grid
        self._nrows = ctx.meta.get('n_row', 1)
        self._ncols = ctx.meta.get('n_col', 1)
        # do not display dev histo
        self._hide_dev = ctx.meta.get('hide_dev', False)
        if self._hide_dev and not self._eval:
            raise click.BadParameter("You can only use --hide-dev along with --eval")

        # dev and eval histograms are displayed side by side (two subplots
        # per system) only when both are shown
        side_by_side = self._eval and not self._hide_dev
        # dev hist are displayed next to eval hist
        self._ncols *= 2 if side_by_side else 1
        self._nlegends = ctx.meta.get('legends_ncol', 3)
        self._legend_loc = self._legend_loc or 'upper center'
        # number of subplot on one page
        self._step_print = int(self._nrows * self._ncols)
        self._title_base = 'Scores'
        self._y_label = 'Probability density'
        self._x_label = 'Score values'
        self._end_setup_plot = False
        if self._legends is not None and len(self._legends) == self.n_systems \
           and side_by_side:
            # subplot indices count dev and eval subplots: use the same
            # legend for the dev and eval subplot of each system. Without
            # eval data there is one subplot per system, so no duplication.
            self._legends = [x for pair in zip(self._legends, self._legends)
                             for x in pair]

    def compute(self, idx, input_scores, input_names):
        ''' Draw histograms of negative and positive scores.'''
        dev_neg, dev_pos, eval_neg, eval_pos, threshold = \
            self._get_neg_pos_thres(idx, input_scores, input_names)
        # from here on idx is a subplot index, not a system index
        idx *= 1 if self._hide_dev or not self._eval else 2

        if not self._hide_dev or not self._eval:
            self._print_subplot(idx, dev_neg, dev_pos, threshold,
                                not self._no_line, False)

        idx += 1 if self._eval and not self._hide_dev else 0
        if self._eval:
            self._print_subplot(idx, eval_neg, eval_pos, threshold,
                                not self._no_line, True)

    def _print_subplot(self, idx, neg, pos, threshold, draw_line, evaluation):
        ''' print a subplot for the given score and subplot index

        ``idx`` is the global subplot index (doubled by :py:meth:`compute`
        when dev and eval histograms are side by side). When the page is
        full, or this is the last subplot, the page is saved to the pdf.
        '''
        # number of subplots per system, and in total
        mult = 2 if self._eval and not self._hide_dev else 1
        n_subplots = self.n_systems * mult
        n = idx % self._step_print
        col = n % self._ncols
        sub_plot_idx = n + 1
        axis = mpl.subplot(self._nrows, self._ncols, sub_plot_idx)
        self._setup_hist(neg, pos)
        if col == 0:
            axis.set_ylabel(self._y_label)
        # subplots left to draw, counted from the first subplot of this page.
        # idx counts subplots, so the total must count subplots as well
        # (not systems), otherwise x labels land on non-bottom rows.
        rest_print = n_subplots - (idx // self._step_print) * self._step_print
        # only label the x axis of subplots with nothing below them
        if n + self._ncols >= min(self._step_print, rest_print):
            axis.set_xlabel(self._x_label)
        dflt_title = "Eval scores" if evaluation else "Dev scores"
        axis.set_title(self._get_title(idx, dflt_title))
        label = "%s threshold%s" % (
            '' if self._criterion is None else
            self._criterion.upper(), ' (dev)' if self._eval else ''
        )
        if draw_line:
            self._lines(threshold, label, neg, pos, idx)

        # if it was the last subplot of the page or the last subplot
        # to display, save figure
        if self._step_print == sub_plot_idx or idx == n_subplots - 1:
            # print legend on the page
            self.plot_legends()
            mpl.tight_layout()
            self._pdf_page.savefig(mpl.gcf(), bbox_inches='tight')
            mpl.clf()
            mpl.figure()

    def _get_title(self, idx, dflt=None):
        ''' Get the histo title for the given idx

        The user legend of subplot ``idx`` if any, else ``dflt``, else
        ``self._title_base``; a title made only of spaces gives ``''``.
        '''
        title = self._legends[idx] if self._legends is not None \
            and idx < len(self._legends) else dflt
        title = title or self._title_base
        title = '' if title is not None and not title.replace(
            ' ', '') else title
        return title or ''

    def plot_legends(self):
        ''' Print legend on current page'''
        lines = []
        labels = []
        for ax in mpl.gcf().get_axes():
            ali, ala = ax.get_legend_handles_labels()
            # avoid duplicates in legend
            for li, la in zip(ali, ala):
                if la not in labels:
                    lines.append(li)
                    labels.append(la)

        if self._disp_legend:
            mpl.gcf().legend(
                lines, labels, loc=self._legend_loc, fancybox=True,
                framealpha=0.5, ncol=self._nlegends,
                bbox_to_anchor=(0.55, 1.1),
            )

    def _get_neg_pos_thres(self, idx, input_scores, input_names):
        ''' Get scores and threshold for the given system at index idx

        Returns ``(dev_neg, dev_pos, eval_neg, eval_pos, threshold)`` where
        each score item is a list (one entry per file of the system) and the
        eval lists are ``None`` without eval data. The threshold is the
        user-provided one for system ``idx``, or computed on the first dev
        set with ``self._criterion``.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        length = len(neg_list)
        # lists returned by get_fta_list contains all the following items:
        # for bio or measure without eval:
        #   [dev]
        # for bio or measure with eval:
        #   [dev, eval]
        # for vuln with {licit,spoof} without eval:
        #   [licit_dev, spoof_dev]
        # for vuln with {licit,spoof} with eval:
        #   [licit_dev, licit_eval, spoof_dev, spoof_eval]
        step = 2 if self._eval else 1
        # can have several files for one system
        dev_neg = [neg_list[x] for x in range(0, length, step)]
        dev_pos = [pos_list[x] for x in range(0, length, step)]
        eval_neg = eval_pos = None
        if self._eval:
            eval_neg = [neg_list[x] for x in range(1, length, step)]
            eval_pos = [pos_list[x] for x in range(1, length, step)]

        threshold = utils.get_thres(
            self._criterion, dev_neg[0], dev_pos[0]
        ) if self._thres is None else self._thres[idx]
        return dev_neg, dev_pos, eval_neg, eval_pos, threshold

    def _density_hist(self, scores, n, **kwargs):
        ''' Plots one density histo, using the ``n``-th bins setting

        Returns the ``(counts, bins, patches)`` tuple of ``mpl.hist``.
        '''
        counts, bins, patches = mpl.hist(
            scores, density=True,
            bins=self._nbins[n],
            **kwargs
        )
        return (counts, bins, patches)

    def _lines(self, threshold, label=None, neg=None, pos=None,
               idx=None, **kwargs):
        ''' Plots vertical line at threshold '''
        label = label or 'Threshold'
        kwargs.setdefault('color', 'C3')
        kwargs.setdefault('linestyle', '--')
        kwargs.setdefault('label', label)
        # plot a vertical threshold line
        mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs)

    def _setup_hist(self, neg, pos):
        ''' This function can be overwritten in derived classes

        Plots all the density histo required in one plot. Here negative and
        positive scores densities.
        '''
        self._density_hist(
            neg[0], n=0,
            label='Negatives', alpha=0.5, color='C3'
        )
        self._density_hist(
            pos[0], n=1,
            label='Positives', alpha=0.5, color='C0'
        )