figure.py 35.1 KB
Newer Older
1 2 3 4
'''Runs error analysis on score sets, outputs metrics and plots'''

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
5
import math
6
import sys
7
import os.path
8
import numpy
9 10 11 12 13
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
14
from .. import (far_threshold, plot, utils, ppndf)
15 16 17
import logging

LOGGER = logging.getLogger("bob.measure")
18

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
19

20 21 22 23 24 25 26 27 28 29 30 31
def check_list_value(values, desired_number, name, name2='systems'):
    """Validate a per-item option list, broadcasting a single value.

    Parameters
    ----------
    values : sequence or None
        Option values given by the user. ``None`` is passed through.
    desired_number : int
        Number of values expected.
    name : str
        Option name, used in the error message.
    name2 : str
        Name of the things the values are counted against, used in the
        error message (default ``'systems'``).

    Returns
    -------
    ``values`` unchanged when it is ``None`` or already has
    ``desired_number`` items. Otherwise, when exactly one value was given,
    that value repeated ``desired_number`` times (``values * desired_number``).

    Raises
    ------
    click.BadParameter
        If ``values`` has neither 1 nor ``desired_number`` items.
    """
    if values is None or len(values) == desired_number:
        return values
    if len(values) != 1:
        raise click.BadParameter(
            '#{} ({}) must be either 1 value or the same as '
            '#{} ({} values)'.format(name, values, name2, desired_number))
    # a single value is shared by every item
    return values * desired_number


32 33 34 35 36 37 38 39 40 41
class MeasureBase(object):
    """Base class for metrics and plots.

    This abstract class defines the framework used to plot or compute
    metrics from a list of (negative, positive) score tuples. Subclasses
    implement :py:meth:`compute` (called once per system) and
    :py:meth:`end_process` (called once after all systems). They may also
    override :py:meth:`init_process` (called once before all systems).

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    n_systems : int
        Number of systems, i.e. ``len(scores) // min_arg``.
    """
    __metaclass__ = ABCMeta  # for python 2.7 compatibility

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx : click context
            Click context. Options are read from its ``meta`` dictionary
            (``legends``, ``min_arg``).
        scores : :any:`list`
            List of input files (e.g. dev-{1, 2, 3}, {dev,eval}-scores1
            {dev,eval}-scores2). Each group of ``min_arg`` consecutive
            files describes one system.
        evaluation : :py:class:`bool`
            True if eval data are used
        func_load : callable
            Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If ``scores`` is empty or its length is not a multiple of
            ``min_arg``, or if fewer legends than systems are given.
        """
        self._scores = scores
        self._ctx = ctx
        self.func_load = func_load
        self._legends = ctx.meta.get('legends')
        self._eval = evaluation
        # number of input files describing one system
        self._min_arg = ctx.meta.get('min_arg', 1)
        if len(scores) < 1 or len(scores) % self._min_arg != 0:
            raise click.BadParameter(
                'Number of argument must be a non-zero multiple of %d' % self._min_arg
            )
        # exact: len(scores) is a multiple of min_arg (checked above)
        self.n_systems = int(len(scores) // self._min_arg)
        if self._legends is not None and len(self._legends) < self.n_systems:
            raise click.BadParameter("Number of legends must be >= to the "
                                     "number of systems")

    def run(self):
        """ Generate outputs (e.g. metrics, files, pdf plots).

        Calls :py:meth:`init_process` once. Then, for each system, calls
        :py:meth:`compute` with that system's loaded scores and file base
        names. Finally calls :py:meth:`end_process`.
        """
        # init matplotlib, log files, ...
        self.init_process()
        # Scores are given as follows:
        # SysA-dev SysA-eval ... SysA-XX  SysB-dev SysB-eval ... SysB-XX
        # ------------------------------  ------------------------------
        #   First set of `self._min_arg`     Second set of input files
        #     input files starting at               for SysB
        #    index idx * self._min_arg
        # Note that more than one dev or eval score file can be passed to
        # each system.
        for idx in range(self.n_systems):
            start = idx * self._min_arg
            input_scores, input_names = self._load_files(
                self._scores[start:start + self._min_arg]
            )
            self.compute(idx, input_scores, input_names)
        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        before iterating through the different systems.
        Should be reimplemented in derived classes; does nothing here."""
        pass

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots from the given scores provided by
        :py:func:`~bob.measure.script.figure.MeasureBase.run`.
        Should be reimplemented in derived classes.

        Example of the structure of ``input_scores`` (vuln example). With
        evaluation data::

            [(dev_licit_neg, dev_licit_pos), (eval_licit_neg, eval_licit_pos),
             (dev_spoof_neg, dev_licit_pos), (eval_spoof_neg, eval_licit_pos)]

        and with dev data only::

            [(dev_licit_neg, dev_licit_pos), (dev_spoof_neg, dev_licit_pos)]

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        """
        pass

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        after iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # common protected functions
    def _load_files(self, filepaths):
        ''' Load the input files and return the base names of the files

        Returns
        -------
            scores: :any:`list`:
                A list that contains the output of
                ``func_load`` for the given files
            basenames: :any:`list`:
                A list of basenames for the given files, truncated at their
                first ``.`` (``'dir/dev.scores.txt'`` gives ``'dev'``)
        '''
        scores = []
        basenames = []
        for filename in filepaths:
            basenames.append(os.path.basename(filename).split(".")[0])
            scores.append(self.func_load(filename))
        return scores, basenames
162

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
163

164 165 166 167 168 169 170 171
class Metrics(MeasureBase):
    ''' Compute metrics from score files

    Attributes
    ----------
    log_file: file-like
        Output stream: ``sys.stdout``, or the file opened from the ``log``
        option.
    names : tuple
        Row labels of the printed table.
    '''

    def __init__(self, ctx, scores, evaluation, func_load,
                 names=('False Positive Rate', 'False Negative Rate',
                        'Precision', 'Recall', 'F1-score')):
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        self.names = names
        self._tablefmt = ctx.meta.get('tablefmt')
        self._criterion = ctx.meta.get('criterion')
        self._open_mode = ctx.meta.get('open_mode')
        self._thres = ctx.meta.get('thres')
        self._decimal = ctx.meta.get('decimal', 2)
        if self._thres is not None:
            if len(self._thres) == 1:
                # a single threshold is shared by every system
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                # n_systems is an int, so it is formatted directly
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._far = ctx.meta.get('far_value')
        self._log = ctx.meta.get('log')
        self.log_file = sys.stdout
        if self._log is not None:
            # open() rejects a None mode: default to overwriting
            self.log_file = open(self._log, self._open_mode or 'w')

    def get_thres(self, criterion, dev_neg, dev_pos, far):
        """Return the threshold given by ``utils.get_thres`` for
        ``criterion`` on the development scores."""
        return utils.get_thres(criterion, dev_neg, dev_pos, far)

    def _numbers(self, neg, pos, threshold, fta):
        """Compute raw rates and counts at ``threshold``.

        Returns
        -------
        tuple
            ``(fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc, precision,
            recall, f1_score)``. ``fm``/``fnm`` are the numbers of false
            accepts/rejects out of ``ni`` negatives/``nc`` positives.
        """
        from .. import (farfrr, precision_recall, f_score)
        # fpr and fnr
        fmr, fnmr = farfrr(neg, pos, threshold)
        hter = (fmr + fnmr) / 2.0
        # far/frr also account for the NaN-score rate ``fta``
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)

        ni = neg.shape[0]  # number of impostors
        fm = int(round(fmr * ni))  # number of false accepts
        nc = pos.shape[0]  # number of clients
        fnm = int(round(fnmr * nc))  # number of false rejects

        # precision and recall
        precision, recall = precision_recall(neg, pos, threshold)

        # f_score
        f1_score = f_score(neg, pos, threshold, 1)
        return (fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc, precision,
                recall, f1_score)

    def _strings(self, metrics):
        """Format the tuple returned by :py:meth:`_numbers`.

        Rates are printed as percentages with ``decimal`` digits; FPR/FNR
        also show their ``(count/total)``.

        Returns
        -------
        tuple
            ``(fta, fmr, fnmr, far, frr, hter, precision, recall, f1)``
            strings.
        """
        fmt = '.%df' % self._decimal

        def percent(value):
            return "%s%%" % format(100 * value, fmt)

        fta_str = percent(metrics[0])
        fmr_str = "%s (%d/%d)" % (percent(metrics[1]), metrics[6], metrics[7])
        fnmr_str = "%s (%d/%d)" % (percent(metrics[2]), metrics[8], metrics[9])
        far_str = percent(metrics[4])
        frr_str = percent(metrics[5])
        hter_str = percent(metrics[3])
        prec_str = "%s" % format(metrics[10], fmt)
        recall_str = "%s" % format(metrics[11], fmt)
        f1_str = "%s" % format(metrics[12], fmt)

        return (fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str,
                prec_str, recall_str, f1_str)

    def _get_all_metrics(self, idx, input_scores, input_names):
        ''' Compute all metrics for dev and eval scores

        The threshold is either computed on the development set with the
        configured criterion, or taken from the user-provided thresholds.
        It is reported to ``log_file`` and applied unchanged to the eval set.

        Returns
        -------
        list
            ``[dev_strings, eval_strings]`` as produced by
            :py:meth:`_strings`. ``eval_strings`` is ``None`` without eval
            data.
        '''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]

        title = self._legends[idx] if self._legends is not None else None
        if self._thres is None:
            threshold = self.get_thres(self._criterion, dev_neg, dev_pos,
                                       self._far)
            far_str = ''
            if self._criterion == 'far' and self._far is not None:
                far_str = str(self._far)
            click.echo("[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                       % (self._criterion.upper(),
                          far_str, title or dev_file,
                          threshold),
                       file=self.log_file)
        else:
            threshold = self._thres[idx]
            # same naming rule as above: prefer the legend over the file
            click.echo("[Min. criterion: user provided] Threshold on "
                       "Development set `%s`: %e"
                       % (title or dev_file, threshold), file=self.log_file)

        res = [self._strings(self._numbers(
            dev_neg, dev_pos, threshold, dev_fta))]
        if self._eval:
            # computes statistics for the eval set based on the threshold a
            # priori
            res.append(self._strings(self._numbers(
                eval_neg, eval_pos, threshold, eval_fta)))
        else:
            res.append(None)
        return res

    def compute(self, idx, input_scores, input_names):
        ''' Compute metrics thresholds and tables (FPR, FNR, precision, recall,
        f1_score) for given system inputs, and print the table to
        ``log_file``. A warning is logged when NaN scores were found.'''
        dev_file = input_names[0]
        title = self._legends[idx] if self._legends is not None else None
        all_metrics = self._get_all_metrics(idx, input_scores, input_names)
        fta_dev = float(all_metrics[0][0].replace('%', ''))
        if fta_dev > 0.0:
            LOGGER.warning("NaNs scores (%s) were found in %s",
                           all_metrics[0][0], dev_file)
        # the first column header shows the system legend when there is one
        headers = [title or ' ', 'Development']
        # positions in the _strings tuple of FPR, FNR, precision, recall, F1
        columns = (1, 2, 6, 7, 8)
        rows = [[self.names[i], all_metrics[0][col]]
                for i, col in enumerate(columns)]

        if self._eval:
            eval_file = input_names[1]
            fta_eval = float(all_metrics[1][0].replace('%', ''))
            if fta_eval > 0.0:
                LOGGER.warning("NaNs scores (%s) were found in %s",
                               all_metrics[1][0], eval_file)
            # computes statistics for the eval set based on the threshold a
            # priori
            headers.append('Evaluation')
            for row, col in zip(rows, columns):
                row.append(all_metrics[1][col])

        click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)

    def end_process(self):
        ''' Close log file if needed'''
        if self._log is not None:
            self.log_file.close()

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
316

317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401
class MultiMetrics(Metrics):
    '''Computes average of metrics based on several protocols (cross
    validation)

    Attributes
    ----------
    log_file : file-like
        output stream
    names : tuple
        List of names for the metrics.
    headers : list
        Table headers; with eval data an extra dev HTER column is inserted
        after ``'Methods'``.
    rows : list
        One table row per system, printed in :py:meth:`end_process`.
    '''

    def __init__(self, ctx, scores, evaluation, func_load,
                 names=('NaNs Rate', 'False Positive Rate',
                        'False Negative Rate', 'False Accept Rate',
                        'False Reject Rate', 'Half Total Error Rate')):
        super(MultiMetrics, self).__init__(
            ctx, scores, evaluation, func_load, names=names)

        self.headers = ['Methods'] + list(self.names)
        if self._eval:
            self.headers.insert(1, self.names[5] + ' (dev)')
        self.rows = []

    def _strings(self, metrics):
        """Format mean (std) over protocols of the six rate metrics.

        ``metrics`` is a numpy array with one row per protocol and one
        column per value returned by ``_numbers``. ``_numbers`` returns 13
        values, so only the first six columns (fta, fmr, fnmr, hter, far,
        frr) are summarised here.
        """
        ftam, fmrm, fnmrm, hterm, farm, frrm = metrics.mean(axis=0)[:6]
        ftas, fmrs, fnmrs, hters, fars, frrs = metrics.std(axis=0)[:6]
        fta_str = "%.1f%% (%.1f%%)" % (100 * ftam, 100 * ftas)
        fmr_str = "%.1f%% (%.1f%%)" % (100 * fmrm, 100 * fmrs)
        fnmr_str = "%.1f%% (%.1f%%)" % (100 * fnmrm, 100 * fnmrs)
        far_str = "%.1f%% (%.1f%%)" % (100 * farm, 100 * fars)
        frr_str = "%.1f%% (%.1f%%)" % (100 * frrm, 100 * frrs)
        hter_str = "%.1f%% (%.1f%%)" % (100 * hterm, 100 * hters)

        return fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str

    def compute(self, idx, input_scores, input_names):
        '''Computes the average of metrics over several protocols.

        Inputs alternate dev/eval per protocol when eval data are used. Each
        eval set is evaluated at the threshold of its own protocol's dev set.
        '''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        step = 2 if self._eval else 1
        self._dev_metrics = []
        self._thresholds = []
        for i in range(0, len(input_scores), step):
            neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
            threshold = self.get_thres(self._criterion, neg, pos, self._far) \
                if self._thres is None else self._thres[idx]
            self._thresholds.append(threshold)
            self._dev_metrics.append(self._numbers(neg, pos, threshold, fta))
        self._dev_metrics = numpy.array(self._dev_metrics)

        if self._eval:
            self._eval_metrics = []
            for i in range(1, len(input_scores), step):
                neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
                # eval file i belongs to protocol i // 2
                threshold = self._thresholds[i // 2]
                self._eval_metrics.append(
                    self._numbers(neg, pos, threshold, fta))
            self._eval_metrics = numpy.array(self._eval_metrics)

        title = self._legends[idx] if self._legends is not None else None

        fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str = \
            self._strings(self._dev_metrics)

        if self._eval:
            # only the dev HTER is shown next to the full eval statistics
            row = [title, hter_str]
            # computes statistics for the eval set based on the threshold a
            # priori
            row.extend(self._strings(self._eval_metrics))
        else:
            row = [title, fta_str, fmr_str, fnmr_str,
                   far_str, frr_str, hter_str]
        self.rows.append(row)

    def end_process(self):
        ''' Print the table of all systems, then close the log file if
        needed.'''
        click.echo(tabulate(self.rows, self.headers,
                            self._tablefmt), file=self.log_file)
        super(MultiMetrics, self).end_process()


402 403 404 405
class PlotBase(MeasureBase):
    ''' Base class for plots. Regroup several options and code
    shared by the different plots
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        self._output = ctx.meta.get('output')
        self._points = ctx.meta.get('points', 100)
        self._split = ctx.meta.get('split')
        self._axlim = ctx.meta.get('axlim')
        self._disp_legend = ctx.meta.get('disp_legend', True)
        self._legend_loc = ctx.meta.get('legend_loc')
        # base-10 exponent of the smallest FPR to display, if known
        self._min_dig = None
        if 'min_far_value' in ctx.meta:
            self._min_dig = int(math.log10(ctx.meta['min_far_value']))
        elif self._axlim is not None and self._axlim[0] is not None:
            self._min_dig = int(math.log10(self._axlim[0])
                                if self._axlim[0] != 0 else 0)
        self._clayout = ctx.meta.get('clayout')
        self._far_at = ctx.meta.get('lines_at')
        # x positions of the vertical lines; subclasses may transform them
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            self._eval_points = {line: [] for line in self._far_at}
            self._lines_val = []
        self._print_fn = ctx.meta.get('show_fn', True)
        self._x_rotation = ctx.meta.get('x_rotation')
        if 'style' in ctx.meta:
            mpl.style.use(ctx.meta['style'])
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._line_linestyles = ctx.meta.get('line_linestyles', False)
        self._linestyles = utils.get_linestyles(
            self.n_systems, self._line_linestyles)
        # a 'titles' option present with value None means "no titles"
        self._titles = (ctx.meta.get('titles') or []) * 2
        # for compatibility
        self._title = ctx.meta.get('title')
        if not self._titles and self._title is not None:
            self._titles = [self._title] * 2

        self._x_label = ctx.meta.get('x_label')
        self._y_label = ctx.meta.get('y_label')
        self._grid_color = 'silver'
        self._pdf_page = None
        # subclasses set this to False to skip the per-figure setup
        self._end_setup_plot = True

    def init_process(self):
        ''' Open pdf and set axis font size if provided

        Uses the ``PdfPages`` object from ``ctx.meta`` when there is one,
        otherwise creates one writing to the ``output`` option. Then creates
        and clears the figures.
        '''
        if not hasattr(matplotlib, 'backends'):
            matplotlib.use('pdf')

        if 'PdfPages' in self._ctx.meta:
            self._pdf_page = self._ctx.meta['PdfPages']
        else:
            self._pdf_page = PdfPages(self._output)

        figsize = self._ctx.meta.get('figsize')
        for i in range(self._nb_figs):
            fig = mpl.figure(i + 1, figsize=figsize)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        ''' Set title, legend, axis labels, grid colors, save figures, draw
        lines and close pdf if needed '''
        # draw vertical lines
        if self._far_at is not None:
            for line, line_trans in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans], [-100.0, 100.], "--",
                    color='black'
                )
                if self._eval and self._split:
                    mpl.figure(2)
                    # join the eval operating points from left to right
                    points = sorted(self._eval_points[line],
                                    key=lambda point: point[0])
                    x_values = [x for x, _ in points]
                    y_values = [y for _, y in points]
                    mpl.plot(x_values,
                             y_values, '--',
                             color='black')
        # only for plots
        if self._end_setup_plot:
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = '' if not self._titles else self._titles[i]
                mpl.title(title if title.replace(' ', '') else '')
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                if self._disp_legend:
                    mpl.legend(loc=self._legend_loc)
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        # A PdfPages created in init_process is ours and must be closed,
        # otherwise the PDF is never finalized. One provided through
        # ctx.meta is closed unless 'closef' is false (e.g. when running
        # evaluate, which keeps adding pages).
        meta = self._ctx.meta
        if self._pdf_page is not None and \
           ('PdfPages' not in meta or meta.get('closef', True)):
            self._pdf_page.close()

    # common protected functions
    def _label(self, base, name, idx):
        ''' Curve label: the user legend if given, else ``base`` with the
        file name (and system number when there are several systems).'''
        if self._legends is not None and len(self._legends) > idx:
            return self._legends[idx]
        if self.n_systems > 1:
            return base + (" %d (%s)" % (idx + 1, name))
        return base + (" (%s)" % name)

    def _set_axis(self):
        ''' Apply the axis limits, if any.'''
        if self._axlim is not None:
            mpl.axis(self._axlim)
517

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
518

519
class Roc(PlotBase):
    ''' Handles the plotting of ROC'''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ['ROC dev.', 'ROC eval.']
        self._x_label = self._x_label or 'FPR'
        self._y_label = self._y_label or "1 - FNR"
        self._semilogx = ctx.meta.get('semilogx', True)
        if not self._legend_loc:
            # the curve hugs a different corner depending on the x scale
            self._legend_loc = 'lower right' if self._semilogx else 'upper right'
        # custom defaults
        if self._axlim is None:
            self._axlim = [None, None, -0.05, 1.05]

    def compute(self, idx, input_scores, input_names):
        ''' Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc`

        The dev curve is drawn on figure 1. The eval curve goes to figure 2
        when ``split`` is set, and dashed on figure 1 otherwise. It is drawn
        together with the eval operating points at the dev thresholds of
        each ``lines_at`` value.
        '''
        negatives, positives, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = negatives[0], positives[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = negatives[1], positives[1]
            eval_file = input_names[1]
        color = self._colors[idx]
        min_dig = self._min_dig or -4

        mpl.figure(1)
        plot.roc_for_far(
            dev_neg, dev_pos,
            far_values=plot.log_values(min_dig),
            CAR=self._semilogx,
            color=color, linestyle=self._linestyles[idx],
            label=self._label('dev', dev_file, idx)
        )
        if not self._eval:
            return

        if self._split:
            mpl.figure(2)
            eval_style = self._linestyles[idx]
        else:
            eval_style = '--'
        plot.roc_for_far(
            eval_neg, eval_pos, linestyle=eval_style,
            far_values=plot.log_values(min_dig),
            CAR=self._semilogx,
            color=color,
            label=self._label('eval.', eval_file, idx)
        )
        if self._far_at is None:
            return

        from .. import farfrr
        for far_line in self._far_at:
            dev_threshold = far_threshold(dev_neg, dev_pos, far_line)
            eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, dev_threshold)
            # the ROC y axis shows 1 - FNR
            eval_tpr = 1 - eval_fnmr
            mpl.scatter(eval_fmr, eval_tpr, c=color, s=30)
            self._eval_points[far_line].append((eval_fmr, eval_tpr))

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
582

583 584
class Det(PlotBase):
    ''' Handles the plotting of DET '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ['DET dev.', 'DET eval.']
        self._x_label = self._x_label or 'FPR (%)'
        self._y_label = self._y_label or 'FNR (%)'
        self._legend_loc = self._legend_loc or 'upper right'
        if self._far_at is not None:
            # vertical lines are drawn in the ppndf-transformed DET scale
            self._trans_far_val = [ppndf(float(k)) for k in self._far_at]
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

        if self._axlim is None:
            self._axlim = [0.01, 99, 0.01, 99]
        else:
            # Work on a private copy. The original object belongs to
            # ctx.meta and is shared with other plots built from the same
            # context, and it may be an immutable tuple.
            self._axlim = list(self._axlim)

        if self._min_dig is not None:
            # DET axes are in percent
            self._axlim[0] = math.pow(10, self._min_dig) * 100

    def compute(self, idx, input_scores, input_names):
        ''' Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`

        The dev curve is drawn on figure 1. The eval curve goes to figure 2
        when ``split`` is set, and dashed on figure 1 otherwise. It is drawn
        together with the eval operating points at the dev thresholds of
        each ``lines_at`` value.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        eval_neg = eval_pos = eval_file = None
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        mpl.figure(1)
        plot.det(
            dev_neg, dev_pos, self._points, color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label('dev.', dev_file, idx)
        )
        # no eval data (or no eval negatives): dev curve only
        if eval_neg is None:
            return

        if self._split:
            mpl.figure(2)
        linestyle = '--' if not self._split else self._linestyles[idx]
        plot.det(
            eval_neg, eval_pos, self._points, color=self._colors[idx],
            linestyle=linestyle,
            label=self._label('eval.', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(
                    eval_neg, eval_pos, thres_line)
                # same ppndf transform as the vertical lines
                eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

    def _set_axis(self):
        ''' Apply the axis limits with the DET-specific axis helper.'''
        plot.det_axis(self._axlim)
647

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
648

649 650
class Epc(PlotBase):
    ''' Handles the plotting of EPC '''

    def __init__(self, ctx, scores, evaluation, func_load, hter='HTER'):
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev. and eval. score files")
        self._titles = self._titles or ['EPC'] * 2
        self._x_label = self._x_label or r'$\alpha$'
        self._y_label = self._y_label or hter + ' (%)'
        self._legend_loc = self._legend_loc or 'upper center'
        # EPC always uses eval data and is drawn on a single figure,
        # without the FPR marker lines
        self._eval, self._split, self._nb_figs, self._far_at = \
            True, False, 1, None

    def compute(self, idx, input_scores, input_names):
        ''' Plot EPC using :py:func:`bob.measure.plot.epc`

        ``input_scores`` holds the dev scores first and the eval scores
        second.
        '''
        negatives, positives, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = negatives[0], positives[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = negatives[1], positives[1]
            eval_file = input_names[1]

        curve_label = self._label('curve', dev_file + "_" + eval_file, idx)
        plot.epc(
            dev_neg, dev_pos, eval_neg, eval_pos, self._points,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=curve_label
        )

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
682

683
class Hist(PlotBase):
    ''' Functional base class for histograms

    For every system, density histograms of negative and positive scores
    (dev. and/or eval.) are drawn on a grid of ``n_row`` x ``n_col``
    subplots; a page is written to the PDF each time the grid is full or the
    last system has been drawn.
    '''

    def __init__(self, ctx, scores, evaluation, func_load, nhist_per_system=2):
        '''
        Parameters
        ----------
        ctx :
            Click context; its ``meta`` dict provides the plotting options
            (``n_bins``, ``thres``, ``criterion``, ``no_line``, ``n_row``,
            ``n_col``, ``hide_dev``, ``legends_ncol``, ``titles``).
        scores, evaluation, func_load :
            Forwarded unchanged to :py:class:`PlotBase`.
        nhist_per_system : int
            Number of histograms drawn in one subplot; ``n_bins`` must give
            either one value or exactly this many.

        Raises
        ------
        click.BadParameter
            If ``--hide-dev`` is used without eval. scores, or if the number
            of ``n_bins`` / threshold values is inconsistent.
        '''
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        self._nhist_per_system = nhist_per_system
        # one bin specification per histogram of a subplot
        self._nbins = check_list_value(
            ctx.meta.get('n_bins', ['doane']), nhist_per_system, 'n_bins',
            'histograms')
        # one threshold per system (None: computed from the criterion)
        self._thres = check_list_value(
            ctx.meta.get('thres'), self.n_systems, 'thresholds')
        self._criterion = ctx.meta.get('criterion')
        # no vertical (threshold) line is displayed
        self._no_line = ctx.meta.get('no_line', False)
        # subplot grid
        self._nrows = ctx.meta.get('n_row', 1)
        self._ncols = ctx.meta.get('n_col', 1)
        # do not display dev histo
        self._hide_dev = ctx.meta.get('hide_dev', False)
        if self._hide_dev and not self._eval:
            raise click.BadParameter(
                "You can only use --hide-dev along with --eval")
        # dev hist are displayed in the row above their eval hist, which
        # doubles the number of rows of the grid
        if self._eval and not self._hide_dev:
            self._nrows *= 2
        self._nlegends = ctx.meta.get('legends_ncol', 3)
        self._legend_loc = self._legend_loc or 'upper center'
        # number of subplot on one page
        self._step_print = int(self._nrows * self._ncols)
        self._title_base = 'Scores'
        self._y_label = self._y_label or 'Probability density'
        self._x_label = self._x_label or 'Score values'
        self._end_setup_plot = False
        # override _titles of PlotBase. The list is duplicated because upper
        # subplots look up index ``sys`` and lower ones ``sys + n_systems``
        # (see _print_subplot). A missing or None option means "no custom
        # titles" instead of raising TypeError on ``None * 2``.
        self._titles = (ctx.meta.get('titles') or []) * 2

    def compute(self, idx, input_scores, input_names):
        ''' Draw histograms of negative and positive scores.

        ``idx`` is the index of the system. When both dev. and eval.
        histograms are shown, the dev. one goes in the row directly above
        the eval. one, so the subplot index is remapped accordingly.
        '''
        dev_neg, dev_pos, eval_neg, eval_pos, threshold = \
            self._get_neg_pos_thres(idx, input_scores, input_names)
        # keep id of the current system (not named ``sys`` so that the
        # module-level ``sys`` import is not shadowed)
        system_idx = idx
        # if the id of the current system does not match the id of the plot,
        # change it: every row of systems is followed by its eval row
        if not self._hide_dev and self._eval:
            row = (idx // self._ncols) * 2
            col = idx % self._ncols
            idx = col + self._ncols * row

        if not self._hide_dev or not self._eval:
            self._print_subplot(idx, system_idx, dev_neg, dev_pos, threshold,
                                not self._no_line, False)

        if self._eval:
            idx += self._ncols if not self._hide_dev else 0
            self._print_subplot(idx, system_idx, eval_neg, eval_pos,
                                threshold, not self._no_line, True)

    def _print_subplot(self, idx, sys, neg, pos, threshold, draw_line, evaluation):
        ''' print a subplot for the given score and subplot index

        Parameters
        ----------
        idx : int
            Global subplot index (as remapped by :py:meth:`compute`).
        sys : int
            Index of the system being drawn (parameter name kept for
            backward compatibility; it shadows the ``sys`` module locally).
        neg, pos :
            Negative / positive scores handed to :py:meth:`_setup_hist`.
        threshold : float
            Position of the vertical threshold line.
        draw_line : bool
            Whether to draw the threshold line.
        evaluation : bool
            True when drawing eval. scores (affects title and x-label).
        '''
        n = idx % self._step_print
        col = n % self._ncols
        sub_plot_idx = n + 1
        axis = mpl.subplot(self._nrows, self._ncols, sub_plot_idx)
        self._setup_hist(neg, pos)
        if col == 0:
            axis.set_ylabel(self._y_label)
        # systems per page; exact integer because _nrows was doubled in
        # __init__ in precisely the case where we divide by 2
        sys_per_page = self._step_print // (1 if self._hide_dev or not
                                            self._eval else 2)
        # position of the system within its page
        sys_idx = sys % sys_per_page
        # rest to be printed (systems from the first one of this page)
        rest_print = self.n_systems - (sys // sys_per_page) * sys_per_page
        # lower histo only
        is_lower = evaluation or not self._eval
        if is_lower and sys_idx + self._ncols >= min(sys_per_page, rest_print):
            axis.set_xlabel(self._x_label)
        dflt_title = "Eval. scores" if evaluation else "Dev. scores"
        if self.n_systems == 1 and (not self._eval or self._hide_dev):
            dflt_title = " "
        add = self.n_systems if is_lower else 0
        axis.set_title(self._get_title(sys + add, dflt_title))
        label = "%s threshold%s" % (
            '' if self._criterion is None else
            self._criterion.upper(), ' (dev)' if self._eval else ''
        )
        if draw_line:
            self._lines(threshold, label, neg, pos, idx)

        # if it was the last subplot of the page or the last subplot
        # to display, save figure
        if self._step_print == sub_plot_idx or (is_lower and sys ==
                                                self.n_systems - 1):
            # print legend on the page
            self.plot_legends()
            mpl.tight_layout()
            page = mpl.gcf()
            self._pdf_page.savefig(page, bbox_inches='tight')
            # close the saved page rather than only clearing it: pyplot keeps
            # every figure alive until it is explicitly closed
            mpl.close(page)
            mpl.figure()

    def _get_title(self, idx, dflt=None):
        ''' Get the histo title for the given idx

        Uses ``self._titles[idx]`` when available, else ``dflt``, else
        ``self._title_base``. A title made only of spaces means "no title"
        and yields ``''``.
        '''
        title = self._titles[idx] if self._titles is not None \
            and idx < len(self._titles) else dflt
        title = title or self._title_base
        title = '' if title is not None and not title.replace(
            ' ', '') else title
        return title or ''

    def plot_legends(self):
        ''' Print legend on current page

        Collects the legend entries of every axis of the current figure,
        keeping the first handle seen for each label, and draws one
        figure-level legend when ``self._disp_legend`` is set.
        '''
        lines = []
        labels = []
        for ax in mpl.gcf().get_axes():
            ali, ala = ax.get_legend_handles_labels()
            # avoid duplicates in legend
            for li, la in zip(ali, ala):
                if la not in labels:
                    lines.append(li)
                    labels.append(la)

        if self._disp_legend:
            mpl.gcf().legend(
                lines, labels, loc=self._legend_loc, fancybox=True,
                framealpha=0.5, ncol=self._nlegends,
                bbox_to_anchor=(0.55, 1.1),
            )

    def _get_neg_pos_thres(self, idx, input_scores, input_names):
        ''' Get scores and threshold for the given system at index idx

        Returns ``(dev_neg, dev_pos, eval_neg, eval_pos, threshold)``; the
        score entries are lists (eval ones are None without eval data). The
        threshold is ``self._thres[idx]`` if thresholds were given, otherwise
        it is computed by ``utils.get_thres`` on the first dev. scores.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        length = len(neg_list)
        # lists returned by get_fta_list contains all the following items:
        # for bio or measure without eval:
        #   [dev]
        # for bio or measure with eval:
        #   [dev, eval]
        # for vuln with {licit,spoof} without eval:
        #   [licit_dev, spoof_dev]
        # for vuln with {licit,spoof} with eval:
        #   [licit_dev, licit_eval, spoof_dev, spoof_eval]
        step = 2 if self._eval else 1
        # can have several files for one system: dev entries are at even
        # positions when eval data are present (every position otherwise),
        # eval entries at odd positions
        dev_neg = [neg_list[x] for x in range(0, length, step)]
        dev_pos = [pos_list[x] for x in range(0, length, step)]
        eval_neg = eval_pos = None
        if self._eval:
            eval_neg = [neg_list[x] for x in range(1, length, step)]
            eval_pos = [pos_list[x] for x in range(1, length, step)]

        threshold = utils.get_thres(
            self._criterion, dev_neg[0], dev_pos[0]
        ) if self._thres is None else self._thres[idx]
        return dev_neg, dev_pos, eval_neg, eval_pos, threshold

    def _density_hist(self, scores, n, **kwargs):
        ''' Plots one density histo

        ``n`` selects the bin specification ``self._nbins[n]``; extra keyword
        arguments are forwarded to :py:func:`matplotlib.pyplot.hist`, whose
        ``(counts, bins, patches)`` result is returned.
        '''
        # results get their own names instead of rebinding the parameter n
        counts, bins, patches = mpl.hist(
            scores, density=True,
            bins=self._nbins[n],
            **kwargs
        )
        return (counts, bins, patches)

    def _lines(self, threshold, label=None, neg=None, pos=None,
               idx=None, **kwargs):
        ''' Plots vertical line at threshold

        ``neg``, ``pos`` and ``idx`` are unused here (presumably kept so
        derived classes can override this method with the same call).
        Defaults for color, linestyle and label can be overridden through
        ``kwargs``.
        '''
        label = label or 'Threshold'
        kwargs.setdefault('color', 'C3')
        kwargs.setdefault('linestyle', '--')
        kwargs.setdefault('label', label)
        # plot a vertical threshold line
        mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs)

    def _setup_hist(self, neg, pos):
        ''' This function can be overwritten in derived classes

        Plots all the density histo required in one plot. Here negative and
        positive scores densities.
        '''
        self._density_hist(
            neg[0], n=0,
            label='Negatives', alpha=0.5, color='C3'
        )
        self._density_hist(
            pos[0], n=1,
            label='Positives', alpha=0.5, color='C0'
        )