figure.py 33.3 KB
Newer Older
1 2 3 4
'''Runs error analysis on score sets, outputs metrics and plots'''

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
5
import math
6
import sys
7
import os.path
8
import numpy
9 10 11 12 13
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
14
from .. import (far_threshold, plot, utils, ppndf)
15

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
16

17 18 19 20 21 22 23 24 25 26 27 28
def check_list_value(values, desired_number, name, name2='systems'):
    """Broadcast a single-valued option to ``desired_number`` entries.

    Parameters
    ----------
    values : sequence or None
        Option values given by the user. ``None`` is returned untouched.
    desired_number : int
        Number of entries the returned sequence must have.
    name : str
        Name of the option, only used in the error message.
    name2 : str
        Name of what the option must match in number, only used in the
        error message.

    Returns
    -------
    sequence or None
        ``values`` itself if it is ``None`` or already has the desired length,
        otherwise its single element repeated ``desired_number`` times.

    Raises
    ------
    click.BadParameter
        If ``values`` has neither one nor ``desired_number`` entries.
    """
    if values is None:
        return values

    count = len(values)
    if count == desired_number:
        return values

    if count != 1:
        raise click.BadParameter(
            '#{} ({}) must be either 1 value or the same as '
            '#{} ({} values)'.format(name, values, name2, desired_number))

    # a single value is shared by every entry
    return values * desired_number


29 30 31 32 33 34 35 36 37 38
class MeasureBase(object):
    """Base class for metrics and plots.

    This abstract class defines the framework to plot or compute metrics from
    a list of (positive, negative) scores tuples.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    n_systems : int
        Number of systems, i.e. the number of input files divided by the
        number of files expected per system (``min_arg``).
    """
    # Only honoured by python 2.7: Python 3 ignores ``__metaclass__``, so the
    # ``abstractmethod`` markers below are not enforced there.
    __metaclass__ = ABCMeta  # for python 2.7 compatibility

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx :
            Click context; the options are read from ``ctx.meta``.

        scores : :any:`list`:
            List of input files (e.g. dev-{1, 2, 3}, {dev,eval}-scores1
            {dev,eval}-scores2)
        evaluation : :py:class:`bool`
            True if eval data are used
        func_load : Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If the number of input files is zero or not a multiple of
            ``min_arg``, or if fewer legends than systems are given.
        """
        self._scores = scores
        self._ctx = ctx
        self.func_load = func_load
        self._legends = ctx.meta.get('legends')
        self._eval = evaluation
        # number of input files expected for each system
        self._min_arg = ctx.meta.get('min_arg', 1)
        if len(scores) < 1 or len(scores) % self._min_arg != 0:
            raise click.BadParameter(
                'Number of argument must be a non-zero multiple of %d' % self._min_arg
            )
        # exact thanks to the check above
        self.n_systems = len(scores) // self._min_arg
        if self._legends is not None and len(self._legends) < self.n_systems:
            raise click.BadParameter("Number of legends must be >= to the "
                                     "number of systems")

    def run(self):
        """ Generate outputs (e.g. metrics, files, pdf plots).
        This function calls abstract methods
        :func:`~bob.measure.script.figure.MeasureBase.init_process` (before
        loop), :py:func:`~bob.measure.script.figure.MeasureBase.compute`
        (in the loop iterating through the different
        systems) and :py:func:`~bob.measure.script.figure.MeasureBase.end_process`
        (after the loop).
        """
        # init matplotlib, log files, ...
        self.init_process()
        # iterates through the different systems and feed `compute`
        # with the dev (and eval) scores of each system
        # Note that more than one dev or eval scores score can be passed to
        # each system
        for idx in range(self.n_systems):
            # Scores are given as followed:
            # SysA-dev SysA-eval ... SysA-XX  SysB-dev SysB-eval ... SysB-XX
            # ------------------------------  ------------------------------
            #   First set of `self._min_arg`     Second set of input files
            #     input files starting at               for SysB
            #    index idx * self._min_arg
            first = idx * self._min_arg
            last = first + self._min_arg
            # load scores for each system: get the corresponding arrays and
            # base-name of files
            input_scores, input_names = self._load_files(
                self._scores[first:last]
            )
            self.compute(idx, input_scores, input_names)
        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        before iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots from the given scores provided by
        :py:func:`~bob.measure.script.figure.MeasureBase.run`.
        Should be reimplemented in derived classes

        Structure of the input is (vuln example), if evaluation is provided::

            [(dev_licit_neg, dev_licit_pos), (eval_licit_neg, eval_licit_pos),
             (dev_spoof_neg, dev_licit_pos), (eval_spoof_neg, eval_licit_pos)]

        and if only dev::

            [(dev_licit_neg, dev_licit_pos), (dev_spoof_neg, dev_licit_pos)]

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        """
        pass

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        after iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # common protected functions

    def _load_files(self, filepaths):
        ''' Load the input files and return the base names of the files

        Parameters
        ----------
        filepaths : :any:`list`
            Paths of the files to load with ``func_load``

        Returns
        -------
            scores: :any:`list`:
                A list that contains the output of
                ``func_load`` for the given files
            basenames: :any:`list`:
                A list of basenames for the given files (file name without
                directory, cut at its first ``.``)
        '''
        scores = []
        basenames = []
        for filename in filepaths:
            basenames.append(os.path.basename(filename).split(".")[0])
            scores.append(self.func_load(filename))
        return scores, basenames
159

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
160

161 162 163 164 165 166 167 168
class Metrics(MeasureBase):
    ''' Compute metrics from score files

    Attributes
    ----------
    log_file:
        output stream: ``sys.stdout``, or the file opened from the ``log``
        option when it is given
    names : tuple
        Names of the six metrics, used as row labels of the table
    '''

    def __init__(self, ctx, scores, evaluation, func_load,
                 names=('NaNs Rate', 'False Positive Rate', 'False Negative Rate',
                        'False Accept Rate', 'False Reject Rate',
                        'Half Total Error Rate')):
        '''
        Parameters
        ----------
        ctx, scores, evaluation, func_load :
            Forwarded to the base class constructor
        names : tuple
            Names of the metrics, in the order: failure-to-acquire, FMR,
            FNMR, FAR, FRR, HTER

        Raises
        ------
        click.BadParameter
            If the number of user thresholds is neither 1 nor the number of
            systems
        '''
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        self.names = names
        self._tablefmt = ctx.meta.get('tablefmt')
        self._criterion = ctx.meta.get('criterion')
        self._open_mode = ctx.meta.get('open_mode')
        self._thres = ctx.meta.get('thres')
        if self._thres is not None:
            if len(self._thres) == 1:
                # a single threshold is shared by all the systems
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                # n_systems is an int: format it directly
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._far = ctx.meta.get('far_value')
        self._log = ctx.meta.get('log')
        self.log_file = sys.stdout
        if self._log is not None:
            # closed in end_process
            self.log_file = open(self._log, self._open_mode)

    def get_thres(self, criterion, dev_neg, dev_pos, far):
        ''' Threshold computed by ``utils.get_thres`` for the given criterion
        and development scores '''
        return utils.get_thres(criterion, dev_neg, dev_pos, far)

    def _numbers(self, neg, pos, threshold, fta):
        ''' Compute the error rates and counts at ``threshold``

        Returns
        -------
        tuple
            ``(fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc)`` where
            ``fm``/``ni`` are the number of false matches / negatives and
            ``fnm``/``nc`` the number of false non-matches / positives
        '''
        from .. import farfrr
        fmr, fnmr = farfrr(neg, pos, threshold)
        # FAR/FRR account for the failure-to-acquire rate
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)
        hter = (far + frr) / 2.0

        ni = neg.shape[0]  # number of impostors
        fm = int(round(fmr * ni))  # number of false accepts
        nc = pos.shape[0]  # number of clients
        fnm = int(round(fnmr * nc))  # number of false rejects
        return fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc

    def _strings(self, fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc):
        ''' Format the output of ``_numbers`` as percentages

        Returns
        -------
        tuple
            ``(fta, fmr, fnmr, far, frr, hter)`` strings, i.e. in the same
            order as ``names``
        '''
        fta_str = "%.1f%%" % (100 * fta)
        fmr_str = "%.1f%% (%d/%d)" % (100 * fmr, fm, ni)
        fnmr_str = "%.1f%% (%d/%d)" % (100 * fnmr, fnm, nc)
        far_str = "%.1f%%" % (100 * far)
        frr_str = "%.1f%%" % (100 * frr)
        hter_str = "%.1f%%" % (100 * hter)

        return fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str

    def compute(self, idx, input_scores, input_names):
        ''' Compute metrics thresholds and tables (FAR, FMR, FNMR, HTER) for
        given system inputs and write them to ``log_file``

        Parameters
        ----------
        idx : int
            index of the system
        input_scores : list
            scores of the system: dev first, then eval when eval data are
            used
        input_names : list
            base names of the files the scores come from, in the same order
        '''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]
            eval_file = input_names[1]

        title = self._legends[idx] if self._legends is not None else None

        if self._thres is None:
            # threshold computed a posteriori on the development set
            threshold = self.get_thres(
                self._criterion, dev_neg, dev_pos, self._far)
            far_str = ''
            if self._criterion == 'far' and self._far is not None:
                far_str = str(self._far)
            click.echo("[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                       % (self._criterion.upper(),
                          far_str, title or dev_file,
                          threshold),
                       file=self.log_file)
        else:
            threshold = self._thres[idx]
            # the legend, when there is one, names the set as in the other
            # branch
            click.echo("[Min. criterion: user provided] Threshold on "
                       "Development set `%s`: %e"
                       % (title or dev_file, threshold), file=self.log_file)

        dev_strings = self._strings(*self._numbers(
            dev_neg, dev_pos, threshold, dev_fta))
        # empty first header when there is no legend
        headers = [title or '', 'Development %s' % dev_file]
        rows = [[name, value] for name, value in zip(self.names, dev_strings)]

        if self._eval:
            # computes statistics for the eval set based on the threshold a
            # priori
            eval_strings = self._strings(*self._numbers(
                eval_neg, eval_pos, threshold, eval_fta))

            headers.append('Eval. %s' % eval_file)
            for row, value in zip(rows, eval_strings):
                row.append(value)

        click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)

    def end_process(self):
        ''' Close log file if needed'''
        if self._log is not None:
            self.log_file.close()

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
281

282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366
class MultiMetrics(Metrics):
    '''Average the metrics of a system over several protocols (cross
    validation) and gather them in a single table.

    Attributes
    ----------
    log_file :
        output stream
    names : tuple
        List of names for the metrics.
    headers : list
        Headers of the table printed in ``end_process``.
    rows : list
        One row of the table per system, filled by ``compute``.
    '''

    def __init__(self, ctx, scores, evaluation, func_load,
                 names=('NaNs Rate', 'False Positive Rate',
                        'False Negative Rate', 'False Accept Rate',
                        'False Reject Rate', 'Half Total Error Rate')):
        super(MultiMetrics, self).__init__(
            ctx, scores, evaluation, func_load, names=names)

        metric_headers = list(self.names)
        if self._eval:
            # with eval data, only the last metric is reported for dev
            metric_headers = [self.names[5] + ' (dev)'] + metric_headers
        self.headers = ['Methods'] + metric_headers
        self.rows = []

    def _strings(self, metrics):
        '''Format the column-wise mean and standard deviation of ``metrics``
        (one row of ``_numbers`` output per protocol) as
        ``"mean% (std%)"`` strings, returned in the order
        fta, fmr, fnmr, far, frr, hter.'''
        (fta_m, fmr_m, fnmr_m, hter_m, far_m, frr_m,
         _, _, _, _) = metrics.mean(axis=0)
        (fta_s, fmr_s, fnmr_s, hter_s, far_s, frr_s,
         _, _, _, _) = metrics.std(axis=0)

        pairs = ((fta_m, fta_s), (fmr_m, fmr_s), (fnmr_m, fnmr_s),
                 (far_m, far_s), (frr_m, frr_s), (hter_m, hter_s))
        return tuple(
            "%.1f%% (%.1f%%)" % (100 * mean, 100 * std)
            for mean, std in pairs
        )

    def compute(self, idx, input_scores, input_names):
        '''Computes the average of metrics over several protocols.

        With eval data, the inputs alternate dev and eval files; each eval
        file is evaluated at the threshold of the dev file preceding it.
        '''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        n_inputs = len(input_scores)
        stride = 2 if self._eval else 1

        dev_numbers = []
        self._thresholds = []
        for k in range(0, n_inputs, stride):
            neg, pos, fta = neg_list[k], pos_list[k], fta_list[k]
            if self._thres is None:
                thres = self.get_thres(self._criterion, neg, pos, self._far)
            else:
                thres = self._thres[idx]
            self._thresholds.append(thres)
            dev_numbers.append(self._numbers(neg, pos, thres, fta))
        self._dev_metrics = numpy.array(dev_numbers)

        if self._eval:
            eval_numbers = []
            for k in range(1, n_inputs, stride):
                neg, pos, fta = neg_list[k], pos_list[k], fta_list[k]
                # threshold of the dev file of the same protocol
                eval_numbers.append(
                    self._numbers(neg, pos, self._thresholds[k // 2], fta))
            self._eval_metrics = numpy.array(eval_numbers)

        title = None
        if self._legends is not None:
            title = self._legends[idx]

        dev_strings = self._strings(self._dev_metrics)
        if not self._eval:
            self.rows.append([title] + list(dev_strings))
            return

        # dev contributes its last metric only, followed by all the
        # statistics of the eval set based on the threshold a priori
        row = [title, dev_strings[5]]
        self.rows.append(row)
        row.extend(self._strings(self._eval_metrics))

    def end_process(self):
        '''Print the table of all the systems, then close the log file if
        needed.'''
        table = tabulate(self.rows, self.headers, self._tablefmt)
        click.echo(table, file=self.log_file)
        super(MultiMetrics, self).end_process()


367 368 369 370
class PlotBase(MeasureBase):
    ''' Base class for plots. Regroup several options and code
    shared by the different plots
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Read the plotting options from ``ctx.meta``; the other parameters
        are forwarded to the base class constructor '''
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        self._output = ctx.meta.get('output')
        self._points = ctx.meta.get('points', 100)
        self._split = ctx.meta.get('split')
        self._axlim = ctx.meta.get('axlim')
        self._disp_legend = ctx.meta.get('disp_legend', True)
        self._legend_loc = ctx.meta.get('legend_loc')
        # power of ten of the smallest FAR value, taken from `min_far_value`
        # or else from the lower x-axis limit
        self._min_dig = None
        if 'min_far_value' in ctx.meta:
            self._min_dig = int(math.log10(ctx.meta['min_far_value']))
        elif self._axlim is not None and self._axlim[0] is not None:
            self._min_dig = int(math.log10(self._axlim[0])
                                if self._axlim[0] != 0 else 0)
        self._clayout = ctx.meta.get('clayout')
        self._far_at = ctx.meta.get('lines_at')
        # x position of the lines in plot coordinates; subclasses that
        # transform their axes replace it
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            self._eval_points = {line: [] for line in self._far_at}
            self._lines_val = []
        self._print_fn = ctx.meta.get('show_fn', True)
        self._x_rotation = ctx.meta.get('x_rotation')
        if 'style' in ctx.meta:
            mpl.style.use(ctx.meta['style'])
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._line_linestyles = ctx.meta.get('line_linestyles', False)
        self._linestyles = utils.get_linestyles(
            self.n_systems, self._line_linestyles)
        # `titles` may be present but None: treat it as no titles.
        # Doubled so that a single title serves both the dev and eval figures
        self._titles = (ctx.meta.get('titles') or []) * 2
        # for compatibility
        self._title = ctx.meta.get('title')
        if not self._titles and self._title is not None:
            self._titles = [self._title] * 2

        self._x_label = ctx.meta.get('x_label')
        self._y_label = ctx.meta.get('y_label')
        self._grid_color = 'silver'
        self._pdf_page = None
        # True when the pdf is created by `init_process` rather than shared
        # through the context
        self._owns_pdf_page = False
        self._end_setup_plot = True

    def init_process(self):
        ''' Open pdf and set axis font size if provided '''
        if not hasattr(matplotlib, 'backends'):
            matplotlib.use('pdf')

        self._owns_pdf_page = 'PdfPages' not in self._ctx.meta
        if self._owns_pdf_page:
            self._pdf_page = PdfPages(self._output)
        else:
            self._pdf_page = self._ctx.meta['PdfPages']

        figsize = self._ctx.meta.get('figsize')
        for i in range(self._nb_figs):
            fig = mpl.figure(i + 1, figsize=figsize)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        ''' Set title, legend, axis labels, grid colors, save figures, draw
        lines and close pdf if needed '''
        # draw vertical lines
        if self._far_at is not None:
            for (line, line_trans) in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans], [-100.0, 100.], "--",
                    color='black'
                )
                if self._eval and self._split:
                    # on the eval figure, link the points of the different
                    # systems obtained for this line, from left to right
                    mpl.figure(2)
                    points = sorted(
                        self._eval_points[line], key=lambda point: point[0]
                    )
                    x_values = [x for x, _ in points]
                    y_values = [y for _, y in points]
                    mpl.plot(x_values,
                             y_values, '--',
                             color='black')
        # only for plots
        if self._end_setup_plot:
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = '' if not self._titles else self._titles[i]
                # a title made of spaces only means no title
                mpl.title(title if title.replace(' ', '') else '')
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                if self._disp_legend:
                    mpl.legend(loc=self._legend_loc)
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        # do not want to close PDF when running evaluate: a pdf shared
        # through the context is closed unless `closef` says otherwise
        meta = self._ctx.meta
        shared_to_close = 'PdfPages' in meta and meta.get('closef', True)
        # a pdf created by `init_process` is not referenced by the context
        # and must be closed here
        if getattr(self, '_owns_pdf_page', False) or shared_to_close:
            self._pdf_page.close()

    # common protected functions

    def _label(self, base, name, idx):
        ''' Label of a curve: the legend of system ``idx`` when given, else
        ``base`` followed by the system number (if several systems) and
        ``name`` '''
        if self._legends is not None and len(self._legends) > idx:
            return self._legends[idx]
        if self.n_systems > 1:
            return base + (" %d (%s)" % (idx + 1, name))
        return base + (" (%s)" % name)

    def _set_axis(self):
        ''' Apply the axis limits to the current figure, if any '''
        if self._axlim is not None:
            mpl.axis(self._axlim)
482

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
483

484
class Roc(PlotBase):
    ''' Handles the plotting of ROC'''

    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Set the ROC defaults (titles, labels, legend location, axis
        limits) for the options that were not provided '''
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ['ROC dev.', 'ROC eval.']
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or "1 - False Negative Rate"
        self._semilogx = ctx.meta.get('semilogx', True)
        best_legend = 'lower right' if self._semilogx else 'upper right'
        self._legend_loc = self._legend_loc or best_legend
        # custom defaults
        if self._axlim is None:
            self._axlim = [None, None, -0.05, 1.05]

    def _plot_roc(self, neg, pos, idx, linestyle, label):
        ''' Draw one ROC curve on the current figure with the color of
        system ``idx`` '''
        plot.roc_for_far(
            neg, pos,
            far_values=plot.log_values(self._min_dig or -4),
            CAR=self._semilogx,
            color=self._colors[idx], linestyle=linestyle,
            label=label
        )

    def compute(self, idx, input_scores, input_names):
        ''' Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc_for_far`

        Parameters
        ----------
        idx : int
            index of the system
        input_scores : list
            scores of the system: dev first, then eval when eval data are
            used
        input_names : list
            base names of the files the scores come from, in the same order
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        # the dev curve is always drawn on the first figure
        mpl.figure(1)
        self._plot_roc(
            dev_neg, dev_pos, idx, self._linestyles[idx],
            self._label('dev', dev_file, idx)
        )
        if not self._eval:
            return

        # the eval curve goes to its own figure when split, else it is
        # dashed on the dev figure
        if self._split:
            mpl.figure(2)
        linestyle = '--' if not self._split else self._linestyles[idx]
        self._plot_roc(
            eval_neg, eval_pos, idx, linestyle,
            self._label('eval.', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                # threshold chosen on dev, applied to eval
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(
                    eval_neg, eval_pos, thres_line)
                eval_fnmr = 1 - eval_fnmr
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
547

548 549
class Det(PlotBase):
    ''' Handles the plotting of DET '''

    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Set the DET defaults (titles, labels, legend location, tick
        rotation, axis limits) for the options that were not provided '''
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ['DET dev.', 'DET eval.']
        self._x_label = self._x_label or 'False Positive Rate (%)'
        self._y_label = self._y_label or 'False Negative Rate (%)'
        self._legend_loc = self._legend_loc or 'upper right'
        if self._far_at is not None:
            # position of the lines on the transformed axis
            self._trans_far_val = [ppndf(float(k)) for k in self._far_at]
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

        if self._axlim is None:
            self._axlim = [0.01, 99, 0.01, 99]
        else:
            # private copy: the limits come from the context, which must not
            # be modified by the assignment below
            self._axlim = list(self._axlim)

        if self._min_dig is not None:
            self._axlim[0] = math.pow(10, self._min_dig) * 100

    def compute(self, idx, input_scores, input_names):
        ''' Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`

        Parameters
        ----------
        idx : int
            index of the system
        input_scores : list
            scores of the system: dev first, then eval when eval data are
            used
        input_names : list
            base names of the files the scores come from, in the same order
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        eval_neg = eval_pos = eval_file = None
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        # the dev curve is always drawn on the first figure
        mpl.figure(1)
        plot.det(
            dev_neg, dev_pos, self._points, color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label('development', dev_file, idx)
        )
        # no eval data asked, or none available
        if eval_neg is None:
            return

        # the eval curve goes to its own figure when split, else it is
        # dashed on the dev figure
        if self._split:
            mpl.figure(2)
        linestyle = '--' if not self._split else self._linestyles[idx]
        plot.det(
            eval_neg, eval_pos, self._points, color=self._colors[idx],
            linestyle=linestyle,
            label=self._label('eval.', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                # threshold chosen on dev, applied to eval
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(
                    eval_neg, eval_pos, thres_line)
                eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

    def _set_axis(self):
        ''' Apply the axis limits with :py:func:`bob.measure.plot.det_axis`
        '''
        plot.det_axis(self._axlim)
612

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
613

614 615
class Epc(PlotBase):
    ''' Handles the plotting of EPC '''

    def __init__(self, ctx, scores, evaluation, func_load, hter='HTER'):
        '''
        Parameters
        ----------
        ctx, scores, evaluation, func_load :
            Forwarded to the base class constructor
        hter : str
            Name of the error rate, used in the default label of the y axis

        Raises
        ------
        click.UsageError
            If the systems are not given as pairs of (dev, eval) score files
        '''
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev. and eval. score files")
        self._titles = self._titles or ['EPC'] * 2
        self._x_label = self._x_label or r'$\alpha$'
        self._y_label = self._y_label or hter + ' (%)'
        self._legend_loc = self._legend_loc or 'upper center'
        self._eval = True  # always eval data with EPC
        # a single figure, without vertical lines
        self._split = False
        self._nb_figs = 1
        self._far_at = None

    def compute(self, idx, input_scores, input_names):
        ''' Plot EPC using :py:func:`bob.measure.plot.epc`

        Parameters
        ----------
        idx : int
            index of the system
        input_scores : list
            scores of the system: dev first, then eval
        input_names : list
            base names of the files the scores come from, in the same order
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        # the constructor guarantees two files per system: dev then eval
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        eval_neg, eval_pos = neg_list[1], pos_list[1]
        dev_file, eval_file = input_names[0], input_names[1]

        plot.epc(
            dev_neg, dev_pos, eval_neg, eval_pos, self._points,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label(
                'curve', dev_file + "_" + eval_file, idx
            )
        )

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
647

648
class Hist(PlotBase):
649
    ''' Functional base class for histograms'''
Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
650

651
    def __init__(self, ctx, scores, evaluation, func_load, nhist_per_system=2):
        '''
        Parameters
        ----------
        ctx, scores, evaluation, func_load :
            Forwarded unchanged to the parent constructor. The following
            keys are then read from ``ctx.meta``: ``n_bins``, ``thres``,
            ``criterion``, ``no_line``, ``n_row``, ``n_col``, ``hide_dev``,
            ``legends_ncol`` and ``titles``.
        nhist_per_system : :obj:`int`
            Number of histograms drawn in one subplot; ``n_bins`` must hold
            either one value or that many values.

        Raises
        ------
        click.BadParameter
            If ``hide_dev`` is requested without evaluation data, or if
            ``n_bins`` / ``thres`` have an unsupported number of values
            (raised by ``check_list_value``).
        '''
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        self._nhist_per_system = nhist_per_system
        self._nbins = check_list_value(
            ctx.meta.get('n_bins', ['doane']), nhist_per_system, 'n_bins',
            'histograms')
        self._thres = check_list_value(
            ctx.meta.get('thres'), self.n_systems, 'thresholds')
        self._criterion = ctx.meta.get('criterion')
        # no vertical (threshold) is displayed
        self._no_line = ctx.meta.get('no_line', False)
        # subplot grid
        self._nrows = ctx.meta.get('n_row', 1)
        self._ncols = ctx.meta.get('n_col', 1)
        # do not display dev histo
        self._hide_dev = ctx.meta.get('hide_dev', False)
        if self._hide_dev and not self._eval:
            raise click.BadParameter(
                "You can only use --hide-dev along with --eval")

        # True when every system gets two subplots (dev next to eval)
        dev_and_eval = self._eval and not self._hide_dev
        if dev_and_eval:
            # dev hist are displayed next to eval hist
            self._ncols *= 2
        self._nlegends = ctx.meta.get('legends_ncol', 3)
        self._legend_loc = self._legend_loc or 'upper center'
        # number of subplot on one page
        self._step_print = int(self._nrows * self._ncols)
        self._title_base = 'Scores'
        self._y_label = 'Probability density'
        self._x_label = 'Score values'
        self._end_setup_plot = False
        # override _titles of PlotBase
        self._titles = ctx.meta.get('titles')
        # One title per system was given but two subplots per system are
        # drawn: reuse each title for both the dev and the eval subplot.
        # This must only happen when both are drawn; otherwise subplot
        # ``i`` would pick up the title at index ``i`` of the doubled list.
        if dev_and_eval and self._titles is not None \
           and len(self._titles) == self.n_systems:
            self._titles = [
                title for title in self._titles for _ in range(2)
            ]
690

691
    def compute(self, idx, input_scores, input_names):
        ''' Draw the histogram subplot(s) of one system.

        The development subplot is drawn unless it is hidden while eval
        data are used; the evaluation subplot is drawn whenever eval data
        are used. Both share the same threshold.
        '''
        (dev_neg, dev_pos,
         eval_neg, eval_pos,
         threshold) = self._get_neg_pos_thres(idx, input_scores, input_names)
        show_line = not self._no_line

        # with dev and eval side by side each system owns two consecutive
        # subplot indices: 2 * idx for dev and 2 * idx + 1 for eval
        side_by_side = self._eval and not self._hide_dev
        if side_by_side:
            idx *= 2

        if not self._hide_dev or not self._eval:
            self._print_subplot(
                idx, dev_neg, dev_pos, threshold, show_line, False)

        if side_by_side:
            idx += 1
        if self._eval:
            self._print_subplot(
                idx, eval_neg, eval_pos, threshold, show_line, True)
705

706
    def _print_subplot(self, idx, neg, pos, threshold, draw_line, evaluation):
        ''' Draw one subplot and save the page once it is complete.

        Parameters
        ----------
        idx : :obj:`int`
            Global subplot index (over all pages).
        neg, pos :
            Negative and positive scores, handed to ``_setup_hist``.
        threshold :
            Threshold handed to ``_lines`` when ``draw_line`` is true.
        draw_line : :obj:`bool`
            Whether the threshold line is drawn.
        evaluation : :obj:`bool`
            Whether the scores are evaluation scores; only changes the
            default title.
        '''
        # subplots drawn per system: dev + eval, or a single one
        per_system = 2 if self._eval and not self._hide_dev else 1
        total = self.n_systems * per_system
        page_size = self._step_print

        # position of this subplot on the current page (0-based / 1-based)
        slot = idx % page_size
        position = slot + 1
        axis = mpl.subplot(self._nrows, self._ncols, position)
        self._setup_hist(neg, pos)

        # y label only on the first column
        if slot % self._ncols == 0:
            axis.set_ylabel(self._y_label)

        # subplots still to be drawn, counted from the start of this page
        remaining = total - (idx // page_size) * page_size
        # x label only when no further subplot comes one row below
        if slot + self._ncols >= min(page_size, remaining):
            axis.set_xlabel(self._x_label)

        if self.n_systems == 1 and per_system == 1:
            # a blank default ends up as an empty title (see _get_title)
            default_title = " "
        elif evaluation:
            default_title = "Eval. scores"
        else:
            default_title = "Dev. scores"
        axis.set_title(self._get_title(idx, default_title))

        criterion = '' if self._criterion is None else self._criterion.upper()
        suffix = ' (dev)' if self._eval else ''
        label = "%s threshold%s" % (criterion, suffix)
        if draw_line:
            self._lines(threshold, label, neg, pos, idx)

        # if it was the last subplot of the page or the last subplot
        # to display, save figure
        if position == page_size or idx == total - 1:
            # print legend on the page
            self.plot_legends()
            mpl.tight_layout()
            self._pdf_page.savefig(mpl.gcf(), bbox_inches='tight')
            mpl.clf()
            mpl.figure()
742

743
    def _get_title(self, idx, dflt=None):
        ''' Return the title of the subplot at index ``idx``.

        The configured title at ``idx`` is used when there is one, ``dflt``
        otherwise; an empty or missing value falls back on
        ``self._title_base``. A title made only of spaces yields ``''``.
        '''
        configured = self._titles
        chosen = dflt
        if configured is not None and idx < len(configured):
            chosen = configured[idx]
        chosen = chosen or self._title_base
        if chosen is None:
            return ''
        if not chosen.replace(' ', ''):
            # blank title: nothing to display
            return ''
        return chosen
751

752 753
    def plot_legends(self):
        ''' Add one figure-level legend to the current page.

        Legend entries are gathered from every axes of the current figure,
        keeping only the first handle seen for each label. The legend is
        drawn only when ``self._disp_legend`` is true.
        '''
        handles = []
        names = []
        for axes in mpl.gcf().get_axes():
            axes_handles, axes_names = axes.get_legend_handles_labels()
            for handle, name in zip(axes_handles, axes_names):
                # avoid duplicates in legend
                if name in names:
                    continue
                handles.append(handle)
                names.append(name)

        if not self._disp_legend:
            return
        mpl.gcf().legend(
            handles, names,
            loc=self._legend_loc,
            fancybox=True,
            framealpha=0.5,
            ncol=self._nlegends,
            bbox_to_anchor=(0.55, 1.1),
        )
770

771
    def _get_neg_pos_thres(self, idx, input_scores, input_names):
        ''' Get scores and threshold for the given system at index idx

        Returns
        -------
        tuple
            ``(dev_neg, dev_pos, eval_neg, eval_pos, threshold)``. The four
            score entries are lists (a system may come with several files);
            the two eval entries are ``None`` when no eval data are used.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        # Layout of the lists returned by get_fta_list:
        # for bio or measure without eval:
        #   [dev]
        # (presumably) for bio or measure with eval:
        #   [dev, eval]
        # for vuln with {licit,spoof} without eval:
        #   [licit_dev, spoof_dev]
        # for vuln with {licit,spoof} with eval:
        #   [licit_dev, licit_eval, spoof_dev, spoof_eval]
        # so with eval data, dev and eval entries alternate.
        count = len(neg_list)
        stride = 2 if self._eval else 1

        # can have several files for one system
        dev_positions = range(0, count, stride)
        dev_neg = [neg_list[pos] for pos in dev_positions]
        dev_pos = [pos_list[pos] for pos in dev_positions]

        eval_neg = eval_pos = None
        if self._eval:
            eval_positions = range(1, count, stride)
            eval_neg = [neg_list[pos] for pos in eval_positions]
            eval_pos = [pos_list[pos] for pos in eval_positions]

        # a user-provided threshold wins; otherwise it is computed from the
        # first development set
        if self._thres is None:
            threshold = utils.get_thres(
                self._criterion, dev_neg[0], dev_pos[0])
        else:
            threshold = self._thres[idx]
        return dev_neg, dev_pos, eval_neg, eval_pos, threshold
797

798
    def _density_hist(self, scores, n, **kwargs):
        ''' Plot one density histogram.

        Parameters
        ----------
        scores :
            Values to plot.
        n : :obj:`int`
            Index into ``self._nbins`` selecting the ``bins`` argument.
        **kwargs
            Forwarded to ``matplotlib.pyplot.hist``.

        Returns
        -------
        tuple
            The three values returned by ``matplotlib.pyplot.hist``.
        '''
        bins_spec = self._nbins[n]
        counts, edges, patches = mpl.hist(
            scores, density=True, bins=bins_spec, **kwargs)
        return (counts, edges, patches)

807 808
    def _lines(self, threshold, label=None, neg=None, pos=None,
               idx=None, **kwargs):
        ''' Plot a vertical line at ``threshold``.

        ``neg``, ``pos`` and ``idx`` are accepted but not used by this
        implementation. ``kwargs`` are forwarded to
        ``matplotlib.pyplot.axvline`` and take precedence over the default
        color, line style and label.
        '''
        style = {
            'color': 'C3',
            'linestyle': '--',
            'label': label or 'Threshold',
        }
        # caller-supplied options override the defaults above
        style.update(kwargs)
        # plot a vertical threshold line
        mpl.axvline(x=threshold, ymin=0, ymax=1, **style)
816 817

    def _setup_hist(self, neg, pos):
        ''' This function can be overwritten in derived classes

        Plots all the density histograms required in one subplot: here the
        densities of the first negative and the first positive score sets,
        using bins settings 0 and 1 respectively.
        '''
        series = (
            (neg, 'Negatives', 'C3'),
            (pos, 'Positives', 'C0'),
        )
        for slot, (score_sets, name, color) in enumerate(series):
            self._density_hist(
                score_sets[0], n=slot,
                label=name, alpha=0.5, color=color
            )