figure.py 33.5 KB
Newer Older
1 2 3 4
'''Runs error analysis on score sets, outputs metrics and plots'''

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
5
import math
6
import sys
7
import os.path
8
import numpy
9 10 11 12 13
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
14
from .. import (far_threshold, plot, utils, ppndf)
15

def check_list_value(values, desired_number, name, name2='systems'):
    '''Broadcast or validate a per-item option.

    Parameters
    ----------
    values : :any:`list` or :any:`tuple` or None
        Option values given by the user. ``None`` is returned unchanged.
    desired_number : int
        Number of values expected (e.g. number of systems).
    name : str
        Name of the option, used in the error message.
    name2 : str
        Name of the counted items, used in the error message.

    Returns
    -------
    ``values`` itself when it is ``None`` or already has ``desired_number``
    entries; otherwise the single given value repeated ``desired_number``
    times (same sequence type as ``values``).

    Raises
    ------
    click.BadParameter
        If ``values`` has neither 1 nor ``desired_number`` entries.
    '''
    if values is None or len(values) == desired_number:
        return values
    if len(values) == 1:
        # sequence repetition keeps the original container type
        return values * desired_number
    raise click.BadParameter(
        '#{} ({}) must be either 1 value or the same as '
        '#{} ({} values)'.format(name, values, name2, desired_number))


class MeasureBase(object):
    """Base class for metrics and plots.

    This abstract class defines the framework to plot or compute metrics
    from a list of (positive, negative) score tuples.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    n_systems : int
        Number of systems, i.e. ``len(scores) // min_arg``.
    """
    # NOTE: ``__metaclass__`` is only honoured by Python 2. On Python 3 this
    # is a plain class attribute, so the abstract methods are not enforced.
    __metaclass__ = ABCMeta  # for python 2.7 compatibility

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx : click context
            Click context; options are read from ``ctx.meta``.
        scores : :any:`list`:
            List of input files (e.g. dev-{1, 2, 3}, {dev,eval}-scores1
            {dev,eval}-scores2)
        evaluation : :py:class:`bool`
            True if eval data are used
        func_load : Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If the number of score files is not a non-zero multiple of
            ``ctx.meta['min_arg']`` (default 1), or if fewer legends than
            systems are given.
        """
        self._scores = scores
        self._ctx = ctx
        self.func_load = func_load
        self._legends = ctx.meta.get('legends')
        self._eval = evaluation
        # number of consecutive score files that belong to one system
        self._min_arg = ctx.meta.get('min_arg', 1)
        if len(scores) < 1 or len(scores) % self._min_arg != 0:
            raise click.BadParameter(
                'Number of argument must be a non-zero multiple of %d' % self._min_arg
            )
        self.n_systems = len(scores) // self._min_arg
        if self._legends is not None and len(self._legends) < self.n_systems:
            raise click.BadParameter("Number of legends must be >= to the "
                                     "number of systems")

    def run(self):
        """ Generate outputs (e.g. metrics, files, pdf plots).

        This function calls abstract methods
        :func:`~bob.measure.script.figure.MeasureBase.init_process` (before
        loop), :py:func:`~bob.measure.script.figure.MeasureBase.compute`
        (in the loop iterating through the different
        systems) and :py:func:`~bob.measure.script.figure.MeasureBase.end_process`
        (after the loop).
        """
        # init matplotlib, log files, ...
        self.init_process()
        # Iterate through the systems and feed `compute` with the dev (and
        # eval) scores of each one. Scores are given as follows:
        # SysA-dev SysA-eval ... SysA-XX  SysB-dev SysB-eval ... SysB-XX
        # ------------------------------  ------------------------------
        #   First set of `self._min_arg`     Second set of input files
        #     input files starting at               for SysB
        #    index idx * self._min_arg
        for idx in range(self.n_systems):
            start = idx * self._min_arg
            input_scores, input_names = self._load_files(
                self._scores[start:start + self._min_arg]
            )
            self.compute(idx, input_scores, input_names)
        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        before iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots from the given scores provided by
        :py:func:`~bob.measure.script.figure.MeasureBase.run`.
        Should be reimplemented in derived classes

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        """
        pass
        # structure of input is (vuln example):
        # if evaluation is provided
        # [ (dev_licit_neg, dev_licit_pos), (eval_licit_neg, eval_licit_pos),
        #   (dev_spoof_neg, dev_licit_pos), (eval_spoof_neg, eval_licit_pos)]
        # and if only dev:
        # [ (dev_licit_neg, dev_licit_pos), (dev_spoof_neg, dev_licit_pos)]

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        after iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # common protected functions
    def _load_files(self, filepaths):
        ''' Load the input files and return the base names of the files

        Returns
        -------
            scores: :any:`list`:
                A list that contains the output of
                ``func_load`` for the given files
            basenames: :any:`list`:
                A list of basenames for the given files, truncated at the
                first ``.`` (e.g. ``/a/dev.scores.txt`` -> ``dev``)
        '''
        scores = [self.func_load(path) for path in filepaths]
        basenames = [os.path.basename(path).split(".")[0]
                     for path in filepaths]
        return scores, basenames

class Metrics(MeasureBase):
    ''' Compute metrics from score files

    Attributes
    ----------
    log_file : file object
        Output stream: ``sys.stdout`` or the file named by
        ``ctx.meta['log']``.
    names : tuple
        Row labels of the metrics table.
    '''

    def __init__(self, ctx, scores, evaluation, func_load,
                 names=('NaNs Rate', 'False Positive Rate', 'False Negative Rate',
                        'False Accept Rate', 'False Reject Rate',
                        'Half Total Error Rate')):
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        self.names = names
        self._tablefmt = ctx.meta.get('tablefmt')
        self._criterion = ctx.meta.get('criterion')
        self._open_mode = ctx.meta.get('open_mode')
        self._thres = ctx.meta.get('thres')
        if self._thres is not None:
            if len(self._thres) == 1:
                # a single user-provided threshold is used for every system
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._far = ctx.meta.get('far_value')
        self._log = ctx.meta.get('log')
        self.log_file = sys.stdout
        if self._log is not None:
            # open() rejects mode=None: default to (over)writing the log
            self.log_file = open(self._log, self._open_mode or 'w')

    def get_thres(self, criterion, dev_neg, dev_pos, far):
        '''Select a decision threshold on the dev scores (delegates to
        ``utils.get_thres``).'''
        return utils.get_thres(criterion, dev_neg, dev_pos, far)

    def _numbers(self, neg, pos, threshold, fta):
        '''Compute raw error rates and counts at ``threshold``.

        Returns
        -------
        tuple
            ``(fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc)``;
            ``far``/``frr`` fold in the failure-to-acquire rate ``fta``.
        '''
        from .. import farfrr
        fmr, fnmr = farfrr(neg, pos, threshold)
        hter = (fmr + fnmr) / 2.0
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)

        ni = neg.shape[0]  # number of impostors
        fm = int(round(fmr * ni))  # number of false accepts
        nc = pos.shape[0]  # number of clients
        fnm = int(round(fnmr * nc))  # number of false rejects
        return fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc

    def _strings(self, fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc):
        '''Format the output of ``_numbers`` as percentages, in table order
        (fta, fmr, fnmr, far, frr, hter).'''
        fta_str = "%.1f%%" % (100 * fta)
        fmr_str = "%.1f%% (%d/%d)" % (100 * fmr, fm, ni)
        fnmr_str = "%.1f%% (%d/%d)" % (100 * fnmr, fnm, nc)
        far_str = "%.1f%%" % (100 * far)
        frr_str = "%.1f%%" % (100 * frr)
        hter_str = "%.1f%%" % (100 * hter)

        return fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str

    def compute(self, idx, input_scores, input_names):
        ''' Compute metrics thresholds and tables (FAR, FMR, FNMR, HTER) for
        given system inputs.

        The threshold is selected on the dev set (or taken from
        ``ctx.meta['thres']``) and applied unchanged to the eval set.'''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]

        if self._thres is None:
            threshold = self.get_thres(
                self._criterion, dev_neg, dev_pos, self._far)
        else:
            threshold = self._thres[idx]

        # the legend, when given, is preferred over the file name
        title = self._legends[idx] if self._legends is not None else None
        if self._thres is None:
            far_str = ''
            if self._criterion == 'far' and self._far is not None:
                far_str = str(self._far)
            click.echo("[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                       % (self._criterion.upper(),
                          far_str, title or dev_file,
                          threshold),
                       file=self.log_file)
        else:
            click.echo("[Min. criterion: user provided] Threshold on "
                       "Development set `%s`: %e"
                       % (title or dev_file, threshold), file=self.log_file)

        dev_strs = self._strings(*self._numbers(
            dev_neg, dev_pos, threshold, dev_fta))
        # first header cell is the system legend (empty when there is none)
        headers = [title or '', 'Development %s' % dev_file]
        rows = [[name, value] for name, value in zip(self.names, dev_strs)]

        if self._eval:
            eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]
            eval_file = input_names[1]
            # eval statistics use the threshold selected a priori on dev
            eval_strs = self._strings(*self._numbers(
                eval_neg, eval_pos, threshold, eval_fta))
            headers.append('Eval. %s' % eval_file)
            for row, value in zip(rows, eval_strs):
                row.append(value)

        click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)

    def end_process(self):
        ''' Close log file if needed'''
        if self._log is not None:
            self.log_file.close()

class MultiMetrics(Metrics):
    '''Computes average of metrics based on several protocols (cross
    validation)

    Attributes
    ----------
    log_file : file object
        output stream
    names : tuple
        List of names for the metrics.
    '''

    def __init__(self, ctx, scores, evaluation, func_load,
                 names=('NaNs Rate', 'False Positive Rate',
                        'False Negative Rate', 'False Accept Rate',
                        'False Reject Rate', 'Half Total Error Rate')):
        super(MultiMetrics, self).__init__(
            ctx, scores, evaluation, func_load, names=names)

        self.headers = ['Methods'] + list(self.names)
        if self._eval:
            # with eval data, only the dev HTER precedes the eval metrics
            self.headers.insert(1, self.names[5] + ' (dev)')
        self.rows = []

    def _strings(self, metrics):
        '''Format "mean (std)" of each metric over the protocols.

        ``metrics`` holds one row per protocol, each row being the tuple
        returned by ``_numbers``. The output order is
        (fta, fmr, fnmr, far, frr, hter).'''
        means = metrics.mean(axis=0)
        stds = metrics.std(axis=0)
        # _numbers columns: fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc
        order = (0, 1, 2, 4, 5, 3)
        return tuple("%.1f%% (%.1f%%)" % (100 * means[col], 100 * stds[col])
                     for col in order)

    def compute(self, idx, input_scores, input_names):
        '''Computes the average of metrics over several protocols.'''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        # with eval data, files alternate dev / eval for each protocol
        step = 2 if self._eval else 1
        self._dev_metrics = []
        self._thresholds = []
        for i in range(0, len(input_scores), step):
            neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
            if self._thres is None:
                threshold = self.get_thres(
                    self._criterion, neg, pos, self._far)
            else:
                threshold = self._thres[idx]
            self._thresholds.append(threshold)
            self._dev_metrics.append(self._numbers(neg, pos, threshold, fta))
        self._dev_metrics = numpy.array(self._dev_metrics)

        title = self._legends[idx] if self._legends is not None else None
        dev_strs = self._strings(self._dev_metrics)

        if not self._eval:
            self.rows.append([title] + list(dev_strs))
            return

        self._eval_metrics = []
        for i in range(1, len(input_scores), step):
            neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
            # reuse the threshold selected on the matching dev set
            threshold = self._thresholds[i // step]
            self._eval_metrics.append(
                self._numbers(neg, pos, threshold, fta))
        self._eval_metrics = numpy.array(self._eval_metrics)

        eval_strs = self._strings(self._eval_metrics)
        # dev HTER followed by every eval metric
        self.rows.append([title, dev_strs[5]] + list(eval_strs))

    def end_process(self):
        '''Print the aggregated table, then close the log file if needed.'''
        click.echo(tabulate(self.rows, self.headers,
                            self._tablefmt), file=self.log_file)
        super(MultiMetrics, self).end_process()


class PlotBase(MeasureBase):
    ''' Base class for plots. Regroup several options and code
    shared by the different plots
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        self._output = ctx.meta.get('output')
        self._points = ctx.meta.get('points', 100)
        self._split = ctx.meta.get('split')
        self._axlim = ctx.meta.get('axlim')
        self._disp_legend = ctx.meta.get('disp_legend', True)
        self._legend_loc = ctx.meta.get('legend_loc')
        # integer part of log10 of the smallest FAR of interest, if known
        self._min_dig = None
        if 'min_far_value' in ctx.meta:
            self._min_dig = int(math.log10(ctx.meta['min_far_value']))
        elif self._axlim is not None and self._axlim[0] is not None:
            self._min_dig = int(math.log10(self._axlim[0])
                                if self._axlim[0] != 0 else 0)
        self._clayout = ctx.meta.get('clayout')
        self._far_at = ctx.meta.get('lines_at')
        # x-positions of the vertical lines (subclasses may transform them)
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            self._eval_points = {line: [] for line in self._far_at}
            self._lines_val = []
        self._print_fn = ctx.meta.get('show_fn', True)
        self._x_rotation = ctx.meta.get('x_rotation')
        if 'style' in ctx.meta:
            mpl.style.use(ctx.meta['style'])
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._line_linestyles = ctx.meta.get('line_linestyles', False)
        self._linestyles = utils.get_linestyles(
            self.n_systems, self._line_linestyles)
        self._titles = ctx.meta.get('titles', []) * 2
        # for compatibility
        self._title = ctx.meta.get('title')
        if not self._titles and self._title is not None:
            self._titles = [self._title] * 2

        self._x_label = ctx.meta.get('x_label')
        self._y_label = ctx.meta.get('y_label')
        self._grid_color = 'silver'
        self._pdf_page = None
        # True when init_process created the PdfPages and must close it
        self._owns_pdf_page = False
        self._end_setup_plot = True

    def init_process(self):
        ''' Open the pdf output and create (and clear) the figures '''
        if not hasattr(matplotlib, 'backends'):
            matplotlib.use('pdf')

        if 'PdfPages' in self._ctx.meta:
            # PdfPages shared through the context (possibly by other plots)
            self._pdf_page = self._ctx.meta['PdfPages']
            self._owns_pdf_page = False
        else:
            self._pdf_page = PdfPages(self._output)
            self._owns_pdf_page = True

        for i in range(self._nb_figs):
            fs = self._ctx.meta.get('figsize')
            fig = mpl.figure(i + 1, figsize=fs)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        ''' Draw vertical lines, set title, legend, axis labels and grid,
        save figures and close the pdf if needed '''
        # draw vertical lines
        if self._far_at is not None:
            for line, line_trans in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans], [-100.0, 100.], "--",
                    color='black'
                )
                if self._eval and self._split:
                    mpl.figure(2)
                    # join the eval operating points, ordered by x (stable)
                    points = sorted(self._eval_points[line],
                                    key=lambda point: point[0])
                    x_values = [x for x, _ in points]
                    y_values = [y for _, y in points]
                    mpl.plot(x_values,
                             y_values, '--',
                             color='black')
        # only for plots
        if self._end_setup_plot:
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = '' if not self._titles else self._titles[i]
                mpl.title(title if title.replace(' ', '') else '')
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                if self._disp_legend:
                    mpl.legend(loc=self._legend_loc)
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        if self._owns_pdf_page:
            # created in init_process: nobody else will finalize the file
            self._pdf_page.close()
        elif 'PdfPages' in self._ctx.meta and \
                self._ctx.meta.get('closef', True):
            # shared pages stay open when closef is False (e.g. evaluate)
            self._pdf_page.close()

    # common protected functions
    def _label(self, base, name, idx):
        '''Legend entry for system ``idx``: the user legend if given,
        otherwise ``base`` with the file ``name`` (and system number).'''
        if self._legends is not None and len(self._legends) > idx:
            return self._legends[idx]
        if self.n_systems > 1:
            return base + (" %d (%s)" % (idx + 1, name))
        return base + (" (%s)" % name)

    def _set_axis(self):
        '''Apply the axis limits, if any.'''
        if self._axlim is not None:
            mpl.axis(self._axlim)

class Roc(PlotBase):
    ''' Handles the plotting of ROC'''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ['ROC dev.', 'ROC eval.']
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or "1 - False Negative Rate"
        self._semilogx = ctx.meta.get('semilogx', True)
        best_legend = 'lower right' if self._semilogx else 'upper right'
        self._legend_loc = self._legend_loc or best_legend
        # custom defaults
        if self._axlim is None:
            self._axlim = [None, None, -0.05, 1.05]

    def compute(self, idx, input_scores, input_names):
        ''' Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc_for_far`'''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]

        # the dev curve is always drawn on the first figure
        mpl.figure(1)
        plot.roc_for_far(
            dev_neg, dev_pos,
            far_values=plot.log_values(self._min_dig or -4),
            CAR=self._semilogx,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label('dev', dev_file, idx)
        )
        if not self._eval:
            return

        eval_neg, eval_pos = neg_list[1], pos_list[1]
        eval_file = input_names[1]
        if self._split:
            mpl.figure(2)

        # dashed eval curve when it shares the figure with the dev curve
        linestyle = '--' if not self._split else self._linestyles[idx]
        plot.roc_for_far(
            eval_neg, eval_pos, linestyle=linestyle,
            far_values=plot.log_values(self._min_dig or -4),
            CAR=self._semilogx,
            color=self._colors[idx],
            label=self._label('eval.', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                # eval operating point at the threshold reaching `line` on dev
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(
                    eval_neg, eval_pos, thres_line)
                # the ROC y-axis is 1 - FNR
                eval_fnmr = 1 - eval_fnmr
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

class Det(PlotBase):
    ''' Handles the plotting of DET '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ['DET dev.', 'DET eval.']
        self._x_label = self._x_label or 'False Positive Rate (%)'
        self._y_label = self._y_label or 'False Negative Rate (%)'
        self._legend_loc = self._legend_loc or 'upper right'
        if self._far_at is not None:
            # vertical lines are placed on the ppndf-transformed axis
            self._trans_far_val = [ppndf(float(k)) for k in self._far_at]
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

        if self._axlim is None:
            self._axlim = [0.01, 99, 0.01, 99]
        else:
            # work on a copy: the value comes from ctx.meta (possibly a
            # tuple, possibly shared with other plots) and is modified below
            self._axlim = list(self._axlim)

        if self._min_dig is not None:
            # NOTE(review): when _min_dig was derived from a user-given
            # axlim[0] (already in %), this rescales it again - confirm.
            self._axlim[0] = math.pow(10, self._min_dig) * 100

    def compute(self, idx, input_scores, input_names):
        ''' Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`'''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]

        # the dev curve is always drawn on the first figure
        mpl.figure(1)
        plot.det(
            dev_neg, dev_pos, self._points, color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label('development', dev_file, idx)
        )
        if not self._eval:
            return
        eval_neg, eval_pos = neg_list[1], pos_list[1]
        if eval_neg is None:
            return
        eval_file = input_names[1]

        if self._split:
            mpl.figure(2)
        # dashed eval curve when it shares the figure with the dev curve
        linestyle = '--' if not self._split else self._linestyles[idx]
        plot.det(
            eval_neg, eval_pos, self._points, color=self._colors[idx],
            linestyle=linestyle,
            label=self._label('eval.', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                # eval operating point at the threshold reaching `line` on dev
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(
                    eval_neg, eval_pos, thres_line)
                eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

    def _set_axis(self):
        '''Apply the DET axis limits and ticks.'''
        plot.det_axis(self._axlim)

class Epc(PlotBase):
    ''' Handles the plotting of EPC '''

    def __init__(self, ctx, scores, evaluation, func_load, hter='HTER'):
        '''
        Parameters
        ----------
        hter : str
            Name of the error rate shown in the y-axis label.

        Raises
        ------
        click.UsageError
            If the score files are not given as dev/eval pairs.
        '''
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev. and eval. score files")
        self._titles = self._titles or ['EPC'] * 2
        self._x_label = self._x_label or r'$\alpha$'
        self._y_label = self._y_label or hter + ' (%)'
        self._legend_loc = self._legend_loc or 'upper center'
        self._eval = True  # always eval data with EPC
        self._split = False
        self._nb_figs = 1
        self._far_at = None

    def compute(self, idx, input_scores, input_names):
        ''' Plot EPC using :py:func:`bob.measure.plot.epc` '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        # EPC always has eval data (forced in __init__)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        eval_neg, eval_pos = neg_list[1], pos_list[1]
        dev_file, eval_file = input_names[0], input_names[1]

        plot.epc(
            dev_neg, dev_pos, eval_neg, eval_pos, self._points,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label(
                'curve', dev_file + "_" + eval_file, idx
            )
        )

648
class Hist(PlotBase):
649
    ''' Functional base class for histograms'''
Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
650

651
    def __init__(self, ctx, scores, evaluation, func_load, nhist_per_system=2):
        ''' Read the histogram options from ``ctx.meta``.

        Parameters
        ----------
        nhist_per_system : int
            Number of histograms drawn per system; ``n_bins`` must provide
            either 1 value or this many values.

        The other parameters are forwarded to :py:class:`PlotBase`.

        Raises
        ------
        click.BadParameter
            If ``n_bins`` / ``thres`` have an invalid number of values, or if
            ``hide_dev`` is requested without evaluation data.
        '''
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        # bin specification(s), expanded to one per histogram of a system
        self._nbins = ctx.meta.get('n_bins', ['doane'])
        self._nhist_per_system = nhist_per_system
        self._nbins = check_list_value(
            self._nbins, nhist_per_system, 'n_bins',
            'histograms')
        # optional user-given thresholds, expanded to one per system
        self._thres = ctx.meta.get('thres')
        self._thres = check_list_value(
            self._thres, self.n_systems, 'thresholds')
        self._criterion = ctx.meta.get('criterion')
        # no vertical (threshold) line is displayed
        self._no_line = ctx.meta.get('no_line', False)
        # subplot grid
        self._nrows = ctx.meta.get('n_row', 1)
        self._ncols = ctx.meta.get('n_col', 1)
        # do not display dev histograms
        self._hide_dev = ctx.meta.get('hide_dev', False)
        if self._hide_dev and not self._eval:
            raise click.BadParameter(
                "You can only use --hide-dev along with --eval")
        # dev histograms are displayed above eval histograms, so each system
        # then takes two subplot rows
        self._nrows *= 1 if self._hide_dev or not self._eval else 2
        self._nlegends = ctx.meta.get('legends_ncol', 3)
        self._legend_loc = self._legend_loc or 'upper center'
        # number of subplots on one page
        self._step_print = int(self._nrows * self._ncols)
        self._title_base = 'Scores'
        self._y_label = 'Probability density'
        self._x_label = 'Score values'
        self._end_setup_plot = False
        # override _titles of PlotBase. The list is doubled so that titles
        # can also be looked up at ``system index + n_systems``.
        # ``or []`` also covers a ``titles`` entry explicitly set to None,
        # which previously raised TypeError on ``None * 2``.
        self._titles = (ctx.meta.get('titles') or []) * 2
684

685
    def compute(self, idx, input_scores, input_names):
        ''' Draw histograms of negative and positive scores.

        Draws the dev histogram of system ``idx`` (unless dev histograms are
        hidden) and, with evaluation data, its eval histogram.
        '''
        dev_neg, dev_pos, eval_neg, eval_pos, threshold = \
            self._get_neg_pos_thres(idx, input_scores, input_names)
        # keep the id of the current system (named so it does not shadow the
        # ``sys`` module); ``idx`` is turned into a subplot index below
        system_idx = idx
        # with both dev and eval histograms shown, every system uses two
        # subplot rows (dev above eval): remap the system id to the position
        # of its dev subplot
        if not self._hide_dev and self._eval:
            row = (idx // self._ncols) * 2
            col = idx % self._ncols
            idx = col + self._ncols * row

        if not self._hide_dev or not self._eval:
            self._print_subplot(idx, system_idx, dev_neg, dev_pos, threshold,
                                not self._no_line, False)

        if self._eval:
            # the eval subplot sits one row below the dev subplot
            idx += self._ncols if not self._hide_dev else 0
            self._print_subplot(idx, system_idx, eval_neg, eval_pos,
                                threshold, not self._no_line, True)
706

707
    def _print_subplot(self, idx, sys, neg, pos, threshold, draw_line,
                       evaluation):
        ''' Print a subplot for the given scores and subplot index.

        Parameters
        ----------
        idx : int
            Global subplot index; reduced modulo ``self._step_print`` to get
            the position on the current page.
        sys : int
            Index of the plotted system. The name shadows the ``sys`` module
            inside this method; it is kept for backward compatibility and
            copied to ``system_idx`` below.
        neg, pos
            Negative / positive scores, forwarded to :py:meth:`_setup_hist`.
        threshold
            Position of the vertical threshold line.
        draw_line : bool
            Whether the threshold line is drawn.
        evaluation : bool
            True when ``neg`` / ``pos`` are evaluation scores.

        When the page is full, or the last lower subplot has been drawn, the
        legend is added, the page is saved to the PDF and a new figure is
        started.
        '''
        system_idx = sys
        n = idx % self._step_print
        col = n % self._ncols
        sub_plot_idx = n + 1
        axis = mpl.subplot(self._nrows, self._ncols, sub_plot_idx)
        self._setup_hist(neg, pos)
        if col == 0:
            axis.set_ylabel(self._y_label)
        # systems per page: each system takes two rows when dev and eval
        # histograms are both shown. Integer division keeps the page
        # arithmetic in ints (_step_print is a multiple of 2 in that case)
        rows_per_system = 1 if self._hide_dev or not self._eval else 2
        sys_per_page = self._step_print // rows_per_system
        # position of the system on its page, and number of systems left to
        # print from the start of that page
        sys_idx = system_idx % sys_per_page
        rest_print = self.n_systems - \
            (system_idx // sys_per_page) * sys_per_page
        # lower histo only: x label on the bottom row of the page
        is_lower = evaluation or not self._eval
        if is_lower and sys_idx + self._ncols >= min(sys_per_page, rest_print):
            axis.set_xlabel(self._x_label)
        dflt_title = "Eval. scores" if evaluation else "Dev. scores"
        if self.n_systems == 1 and (not self._eval or self._hide_dev):
            dflt_title = " "
        add = self.n_systems if is_lower else 0
        axis.set_title(self._get_title(system_idx + add, dflt_title))
        criterion = '' if self._criterion is None else \
            self._criterion.upper()
        # strip() avoids a leading space in the label when no criterion is set
        label = ("%s threshold%s" % (
            criterion, ' (dev)' if self._eval else ''
        )).strip()
        if draw_line:
            self._lines(threshold, label, neg, pos, idx)

        # if it was the last subplot of the page or the last subplot
        # to display, save figure
        if self._step_print == sub_plot_idx or (
                is_lower and system_idx == self.n_systems - 1):
            # print legend on the page
            self.plot_legends()
            mpl.tight_layout()
            fig = mpl.gcf()
            self._pdf_page.savefig(fig, bbox_inches='tight')
            # close the saved page instead of only clearing it, so that one
            # open figure per page does not accumulate; then start a new one
            mpl.close(fig)
            mpl.figure()
749

750
    def _get_title(self, idx, dflt=None):
        ''' Return the title of histogram ``idx``.

        Uses ``self._titles[idx]`` when it exists, otherwise ``dflt``; an
        empty result falls back to ``self._title_base``. A title made only of
        spaces yields ''.
        '''
        titles = self._titles
        if titles is not None and idx < len(titles):
            chosen = titles[idx]
        else:
            chosen = dflt
        if not chosen:
            chosen = self._title_base
        if chosen is not None and not chosen.replace(' ', ''):
            return ''
        return chosen or ''
758

759 760
    def plot_legends(self):
        ''' Print legend on current page.

        Collects the legend entries of every axes of the current figure,
        keeping only the first handle for each label, and draws one figure
        legend when ``self._disp_legend`` is set.
        '''
        handles = []
        seen_labels = []
        for axes in mpl.gcf().get_axes():
            for handle, text in zip(*axes.get_legend_handles_labels()):
                # avoid duplicates in legend
                if text in seen_labels:
                    continue
                handles.append(handle)
                seen_labels.append(text)

        if not self._disp_legend:
            return
        mpl.gcf().legend(
            handles, seen_labels, loc=self._legend_loc, fancybox=True,
            framealpha=0.5, ncol=self._nlegends,
            bbox_to_anchor=(0.55, 1.1),
        )
777

778
    def _get_neg_pos_thres(self, idx, input_scores, input_names):
        ''' Get scores and threshold for the given system at index idx

        Returns
        -------
        tuple
            ``(dev_neg, dev_pos, eval_neg, eval_pos, threshold)``. The score
            entries are lists holding the dev (resp. eval) entries of the
            system; ``eval_neg`` / ``eval_pos`` are None without evaluation
            data. ``threshold`` is ``self._thres[idx]`` when thresholds were
            given, otherwise it is computed by ``utils.get_thres`` with
            ``self._criterion`` on the first dev negatives/positives.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        # lists returned by get_fta_list contain the following items:
        # for bio or measure without eval:
        #   [dev]
        # for bio or measure with eval:
        #   [dev, eval]
        # for vuln with {licit,spoof} without eval:
        #   [licit_dev, spoof_dev]
        # for vuln with {licit,spoof} with eval:
        #   [licit_dev, licit_eval, spoof_dev, spoof_eval]
        # i.e. with eval data, dev and eval entries alternate.
        step = 2 if self._eval else 1

        # can have several files for one system
        dev_neg = list(neg_list[0::step])
        dev_pos = list(pos_list[0::step])
        eval_neg = eval_pos = None
        if self._eval:
            eval_neg = list(neg_list[1::step])
            eval_pos = list(pos_list[1::step])

        threshold = utils.get_thres(
            self._criterion, dev_neg[0], dev_pos[0]
        ) if self._thres is None else self._thres[idx]
        return dev_neg, dev_pos, eval_neg, eval_pos, threshold
804

805
    def _density_hist(self, scores, n, **kwargs):
        ''' Plot one density-normalised histogram of ``scores``.

        ``n`` selects the bin specification in ``self._nbins``; extra keyword
        arguments go to ``mpl.hist``. Returns the ``(values, bins, patches)``
        tuple produced by ``mpl.hist``.
        '''
        bin_spec = self._nbins[n]
        values, edges, patches = mpl.hist(
            scores, density=True, bins=bin_spec, **kwargs
        )
        return (values, edges, patches)

814 815
    def _lines(self, threshold, label=None, neg=None, pos=None,
               idx=None, **kwargs):
        ''' Plot a vertical line at ``threshold`` on the current axes.

        Defaults to a dashed ``'C3'`` line labelled ``label`` (or
        ``'Threshold'``); any keyword argument overrides these defaults.
        ``neg``, ``pos`` and ``idx`` are not used by this implementation.
        '''
        line_style = {
            'color': 'C3',
            'linestyle': '--',
            'label': label or 'Threshold',
        }
        # caller-supplied options win over the defaults
        line_style.update(kwargs)
        mpl.axvline(x=threshold, ymin=0, ymax=1, **line_style)
823 824

    def _setup_hist(self, neg, pos):
        ''' This function can be overwritten in derived classes

        Plots all the density histograms required in one plot: here the
        negative and positive score densities. Only the first entry of
        ``neg`` and of ``pos`` is drawn, using bin specification 0 and 1
        respectively.
        '''
        series = ((neg, 'Negatives', 'C3'), (pos, 'Positives', 'C0'))
        for bins_idx, (scores, name, colour) in enumerate(series):
            self._density_hist(
                scores[0], n=bins_idx,
                label=name, alpha=0.5, color=colour
            )