figure.py 35.8 KB
Newer Older
1 2 3 4
'''Runs error analysis on score sets, outputs metrics and plots'''

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
5
import math
6
import sys
7
import os.path
8
import numpy
9 10 11 12 13
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
14
from .. import (far_threshold, plot, utils, ppndf)
15 16 17
import logging

LOGGER = logging.getLogger("bob.measure")
18

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
19

20 21 22 23 24 25 26 27 28 29 30 31
def check_list_value(values, desired_number, name, name2='systems'):
    """Broadcast a single option value or validate a per-system list.

    Parameters
    ----------
    values : sequence or None
        Values given for the option; ``None`` is returned unchanged.
    desired_number : int
        Number of values expected.
    name : str
        Name of the option, used in the error message.
    name2 : str
        Name of the counted entity, used in the error message.

    Returns
    -------
    ``values`` itself when it is ``None`` or already has ``desired_number``
    entries; ``values * desired_number`` when exactly one value was given.

    Raises
    ------
    click.BadParameter
        If the number of values is neither 1 nor ``desired_number``.
    """
    if values is None or len(values) == desired_number:
        return values
    if len(values) != 1:
        message = ('#{} ({}) must be either 1 value or the same as '
                   '#{} ({} values)').format(name, values, name2,
                                             desired_number)
        raise click.BadParameter(message)
    # a single value is repeated once per expected entry
    return values * desired_number


32 33 34 35 36 37 38 39 40 41
class MeasureBase(object):
    """Base class for metrics and plots.

    This abstract class defines the framework to plot or compute metrics from
    a list of (positive, negative) score tuples.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    """
    # NOTE: this attribute only sets the metaclass on Python 2; on Python 3 it
    # is a plain class attribute, so ``@abstractmethod`` is not enforced there.
    __metaclass__ = ABCMeta  # for python 2.7 compatibility

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx : :py:class:`click.Context`
            Click context. Options are read from its ``meta`` dictionary
            (``legends``, ``min_arg``).
        scores : :any:`list`:
            List of input files (e.g. dev-{1, 2, 3}, {dev,eval}-scores1
            {dev,eval}-scores2)
        evaluation : :py:class:`bool`
            True if eval data are used
        func_load : Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If the number of score files is not a non-zero multiple of
            ``min_arg``, or if fewer legends than systems are given.
        """
        self._scores = scores
        self._ctx = ctx
        self.func_load = func_load
        self._legends = ctx.meta.get('legends')
        self._eval = evaluation
        # number of consecutive score files that belong to one system
        self._min_arg = ctx.meta.get('min_arg', 1)
        if len(scores) < 1 or len(scores) % self._min_arg != 0:
            raise click.BadParameter(
                'Number of argument must be a non-zero multiple of %d' % self._min_arg
            )
        self.n_systems = len(scores) // self._min_arg
        if self._legends is not None and len(self._legends) < self.n_systems:
            raise click.BadParameter("Number of legends must be >= to the "
                                     "number of systems")

    def run(self):
        """ Generate outputs (e.g. metrics, files, pdf plots).
        This function calls abstract methods
        :func:`~bob.measure.script.figure.MeasureBase.init_process` (before
        loop), :py:func:`~bob.measure.script.figure.MeasureBase.compute`
        (in the loop iterating through the different
        systems) and :py:func:`~bob.measure.script.figure.MeasureBase.end_process`
        (after the loop).
        """
        # init matplotlib, log files, ...
        self.init_process()
        # iterate through the different systems and feed `compute` with the
        # dev (and eval) scores of each system. Note that more than one dev
        # or eval score file can be passed to each system.
        for idx in range(self.n_systems):
            # Scores are given as followed:
            # SysA-dev SysA-eval ... SysA-XX  SysB-dev SysB-eval ... SysB-XX
            # ------------------------------  ------------------------------
            #   First set of `self._min_arg`     Second set of input files
            #     input files starting at               for SysB
            #    index idx * self._min_arg
            start = idx * self._min_arg
            input_scores, input_names = self._load_files(
                self._scores[start:start + self._min_arg]
            )
            LOGGER.info("-----Input files for system %d-----", idx + 1)
            for i, name in enumerate(input_names):
                if not self._eval:
                    LOGGER.info("Dev. score %d: %s", i + 1, name)
                elif i % 2 == 0:
                    # files alternate dev/eval: even indexes are dev. files
                    LOGGER.info("Dev. score %d: %s", i // 2 + 1, name)
                else:
                    LOGGER.info("Eval. score %d: %s", i // 2 + 1, name)
            LOGGER.info("----------------------------------")
            self.compute(idx, input_scores, input_names)
        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        before iterating through the different systems.
        Should be reimplemented in derived classes; does nothing here."""
        pass

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots from the given scores provided by
        :py:func:`~bob.measure.script.figure.MeasureBase.run`.
        Should be reimplemented in derived classes

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        """
        # structure of input is (vuln example):
        # if evaluation is provided
        # [ (dev_licit_neg, dev_licit_pos), (eval_licit_neg, eval_licit_pos),
        #   (dev_spoof_neg, dev_licit_pos), (eval_spoof_neg, eval_licit_pos)]
        # and if only dev:
        # [ (dev_licit_neg, dev_licit_pos), (dev_spoof_neg, dev_licit_pos)]
        pass

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        after iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # common protected functions
    def _load_files(self, filepaths):
        ''' Load the input files and return their names without extension

        Parameters
        ----------
        filepaths : :any:`list`
            Paths of the score files to load.

        Returns
        -------
            scores: :any:`list`:
                A list that contains the output of
                ``func_load`` for the given files
            basenames: :any:`list`:
                The given paths with everything from the first ``.`` of the
                file name removed (the directory part is kept unchanged)
        '''
        scores = []
        basenames = []
        for filename in filepaths:
            # Only strip the extension from the last path component:
            # splitting the whole path on "." truncated paths such as
            # "./dev.txt" (-> "") or "res.v2/dev.txt" (-> "res").
            directory, name = os.path.split(filename)
            # hidden files (".scores") keep their full name
            stem = name.split(".")[0] or name
            basenames.append(os.path.join(directory, stem))
            scores.append(self.func_load(filename))
        return scores, basenames
173

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
174

175 176 177 178 179 180 181 182
class Metrics(MeasureBase):
    ''' Compute metrics from score files

    Attributes
    ----------
    log_file: file-like
        Output stream the tables are written to: ``sys.stdout`` or the file
        given by the ``log`` option.
    names : tuple
        Row labels of the printed table.
    '''

    def __init__(self, ctx, scores, evaluation, func_load,
                 names=('False Positive Rate', 'False Negative Rate',
                        'Precision', 'Recall', 'F1-score')):
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        self.names = names
        self._tablefmt = ctx.meta.get('tablefmt')
        self._criterion = ctx.meta.get('criterion')
        self._open_mode = ctx.meta.get('open_mode')
        self._thres = ctx.meta.get('thres')
        self._decimal = ctx.meta.get('decimal', 2)
        if self._thres is not None:
            if len(self._thres) == 1:
                # a single user threshold is reused for every system
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                # ``n_systems`` is an int and is formatted directly
                # (calling ``len`` on it raised a TypeError)
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._far = ctx.meta.get('far_value')
        self._log = ctx.meta.get('log')
        self.log_file = sys.stdout
        if self._log is not None:
            self.log_file = open(self._log, self._open_mode)

    def get_thres(self, criterion, dev_neg, dev_pos, far):
        '''Return the threshold computed by ``utils.get_thres`` on the
        development scores for ``criterion`` (and ``far`` value).'''
        return utils.get_thres(criterion, dev_neg, dev_pos, far)

    def _numbers(self, neg, pos, threshold, fta):
        '''Compute the raw metrics at ``threshold``.

        Parameters
        ----------
        neg, pos : array-like
            Negative and positive scores (``.shape[0]`` is used as count).
        threshold : float
            Decision threshold.
        fta : float
            Rate of NaN scores (failures to acquire).

        Returns
        -------
        tuple
            ``(fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc, precision,
            recall, f1_score)``; ``fm``/``ni`` are the numbers of false
            accepts / negatives and ``fnm``/``nc`` the numbers of false
            rejects / positives.
        '''
        from .. import (farfrr, precision_recall, f_score)
        # fpr and fnr
        fmr, fnmr = farfrr(neg, pos, threshold)
        hter = (fmr + fnmr) / 2.0
        # FAR/FRR also account for the failures to acquire
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)

        ni = neg.shape[0]  # number of impostors
        fm = int(round(fmr * ni))  # number of false accepts
        nc = pos.shape[0]  # number of clients
        fnm = int(round(fnmr * nc))  # number of false rejects

        # precision and recall
        precision, recall = precision_recall(neg, pos, threshold)

        # f_score
        f1_score = f_score(neg, pos, threshold, 1)
        return (fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc, precision,
                recall, f1_score)

    def _strings(self, metrics):
        '''Format the tuple returned by ``_numbers`` into strings.

        Returns
        -------
        tuple
            ``(fta, fmr, fnmr, far, frr, hter, precision, recall, f1)``
            strings; rates are percentages and FMR/FNMR also show the
            ``(errors/total)`` counts.
        '''
        fmt = '.%df' % self._decimal
        fta_str = "%s%%" % format(100 * metrics[0], fmt)
        fmr_str = "%s%% (%d/%d)" % (format(100 * metrics[1], fmt),
                                    metrics[6], metrics[7])
        fnmr_str = "%s%% (%d/%d)" % (format(100 * metrics[2], fmt),
                                     metrics[8], metrics[9])
        far_str = "%s%%" % format(100 * metrics[4], fmt)
        frr_str = "%s%%" % format(100 * metrics[5], fmt)
        hter_str = "%s%%" % format(100 * metrics[3], fmt)
        prec_str = "%s" % format(metrics[10], fmt)
        recall_str = "%s" % format(metrics[11], fmt)
        f1_str = "%s" % format(metrics[12], fmt)

        return (fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str,
                prec_str, recall_str, f1_str)

    def _get_all_metrics(self, idx, input_scores, input_names):
        ''' Compute all metrics for dev and eval scores

        The threshold is either user provided or computed on the dev. scores
        and is then applied unchanged to the eval. scores.

        Returns
        -------
        list
            ``[dev_strings, eval_strings]`` as returned by ``_strings``;
            ``eval_strings`` is ``None`` without evaluation data.
        '''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]

        threshold = self.get_thres(self._criterion, dev_neg, dev_pos, self._far) \
            if self._thres is None else self._thres[idx]

        title = self._legends[idx] if self._legends is not None else None
        if self._thres is None:
            far_str = ''
            if self._criterion == 'far' and self._far is not None:
                far_str = str(self._far)
            click.echo("[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                       % (self._criterion.upper(),
                          far_str, title or dev_file,
                          threshold),
                       file=self.log_file)
        else:
            # same naming rule as above: prefer the legend over the file name
            click.echo("[Min. criterion: user provided] Threshold on "
                       "Development set `%s`: %e"
                       % (title or dev_file, threshold), file=self.log_file)

        res = []
        res.append(self._strings(self._numbers(
            dev_neg, dev_pos, threshold, dev_fta)))

        if self._eval:
            # computes statistics for the eval set based on the threshold a
            # priori
            res.append(self._strings(self._numbers(
                eval_neg, eval_pos, threshold, eval_fta)))
        else:
            res.append(None)

        return res

    def compute(self, idx, input_scores, input_names):
        ''' Compute metrics thresholds and tables (FPR, FNR, precision, recall,
        f1_score) for given system inputs'''
        dev_file = input_names[0]
        title = self._legends[idx] if self._legends is not None else None
        all_metrics = self._get_all_metrics(idx, input_scores, input_names)
        fta_dev = float(all_metrics[0][0].replace('%', ''))
        if fta_dev > 0.0:
            LOGGER.error("NaNs scores (%s) were found in %s", all_metrics[0][0],
                         dev_file)
        # the legend (when given) titles the first column; the former
        # ``' ' or title`` always evaluated to ``' '``
        headers = [title or ' ', 'Development']
        rows = [[self.names[0], all_metrics[0][1]],
                [self.names[1], all_metrics[0][2]],
                [self.names[2], all_metrics[0][6]],
                [self.names[3], all_metrics[0][7]],
                [self.names[4], all_metrics[0][8]]]

        if self._eval:
            eval_file = input_names[1]
            fta_eval = float(all_metrics[1][0].replace('%', ''))
            if fta_eval > 0.0:
                LOGGER.error("NaNs scores (%s) were found in %s",
                             all_metrics[1][0], eval_file)
            # computes statistics for the eval set based on the threshold a
            # priori
            headers.append('Evaluation')
            rows[0].append(all_metrics[1][1])
            rows[1].append(all_metrics[1][2])
            rows[2].append(all_metrics[1][6])
            rows[3].append(all_metrics[1][7])
            rows[4].append(all_metrics[1][8])

        click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)

    def end_process(self):
        ''' Close log file if needed'''
        if self._log is not None:
            self.log_file.close()

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
327

328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412
class MultiMetrics(Metrics):
    '''Computes average of metrics based on several protocols (cross
    validation)

    Attributes
    ----------
    log_file : file-like
        output stream
    names : tuple
        List of names for the metrics.
    '''

    def __init__(self, ctx, scores, evaluation, func_load,
                 names=('NaNs Rate', 'False Positive Rate',
                        'False Negative Rate', 'False Accept Rate',
                        'False Reject Rate', 'Half Total Error Rate')):
        super(MultiMetrics, self).__init__(
            ctx, scores, evaluation, func_load, names=names)

        self.headers = ['Methods'] + list(self.names)
        if self._eval:
            # with eval. data, the dev. HTER gets its own column before the
            # eval. metrics
            self.headers.insert(1, self.names[5] + ' (dev)')
        self.rows = []

    def _strings(self, metrics):
        '''Format "mean (std)" strings of the rates over the protocols.

        Parameters
        ----------
        metrics : :py:class:`numpy.ndarray`
            2D array with one row per protocol, each row being the tuple
            returned by ``_numbers``.

        Returns
        -------
        tuple
            ``(fta, fmr, fnmr, far, frr, hter)`` strings.
        '''
        # ``_numbers`` returns 13 columns (rates, counts, precision, recall,
        # F1); only the first six rates are reported here. Unpacking every
        # column into a fixed smaller set of names raised a ValueError.
        ftam, fmrm, fnmrm, hterm, farm, frrm = metrics.mean(axis=0)[:6]
        ftas, fmrs, fnmrs, hters, fars, frrs = metrics.std(axis=0)[:6]
        fta_str = "%.1f%% (%.1f%%)" % (100 * ftam, 100 * ftas)
        fmr_str = "%.1f%% (%.1f%%)" % (100 * fmrm, 100 * fmrs)
        fnmr_str = "%.1f%% (%.1f%%)" % (100 * fnmrm, 100 * fnmrs)
        far_str = "%.1f%% (%.1f%%)" % (100 * farm, 100 * fars)
        frr_str = "%.1f%% (%.1f%%)" % (100 * frrm, 100 * frrs)
        hter_str = "%.1f%% (%.1f%%)" % (100 * hterm, 100 * hters)

        return fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str

    def compute(self, idx, input_scores, input_names):
        '''Computes the average of metrics over several protocols.

        Input files alternate dev./eval. when evaluation data is used; a
        threshold is computed (or taken from the user) on each dev. set and
        applied to the matching eval. set.
        '''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        step = 2 if self._eval else 1
        self._dev_metrics = []
        self._thresholds = []
        for i in range(0, len(input_scores), step):
            neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
            threshold = self.get_thres(self._criterion, neg, pos, self._far) \
                if self._thres is None else self._thres[idx]
            self._thresholds.append(threshold)
            self._dev_metrics.append(self._numbers(neg, pos, threshold, fta))
        self._dev_metrics = numpy.array(self._dev_metrics)

        if self._eval:
            self._eval_metrics = []
            for i in range(1, len(input_scores), step):
                neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
                # eval. file i pairs with dev. file i - 1, i.e. threshold i//2
                threshold = self._thresholds[i // 2]
                self._eval_metrics.append(
                    self._numbers(neg, pos, threshold, fta))
            self._eval_metrics = numpy.array(self._eval_metrics)

        title = self._legends[idx] if self._legends is not None else None

        fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str = \
            self._strings(self._dev_metrics)

        if self._eval:
            self.rows.append([title, hter_str])
        else:
            self.rows.append([title, fta_str, fmr_str, fnmr_str,
                              far_str, frr_str, hter_str])

        if self._eval:
            # computes statistics for the eval set based on the threshold a
            # priori
            fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str = \
                self._strings(self._eval_metrics)

            self.rows[-1].extend([fta_str, fmr_str, fnmr_str,
                                  far_str, frr_str, hter_str])

    def end_process(self):
        '''Print the table of all systems, then close the log file if needed.'''
        click.echo(tabulate(self.rows, self.headers,
                            self._tablefmt), file=self.log_file)
        super(MultiMetrics, self).end_process()


413 414 415 416
class PlotBase(MeasureBase):
    ''' Base class for plots. Regroup several options and code
    shared by the different plots
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        self._output = ctx.meta.get('output')
        self._points = ctx.meta.get('points', 100)
        self._split = ctx.meta.get('split')
        self._axlim = ctx.meta.get('axlim')
        self._disp_legend = ctx.meta.get('disp_legend', True)
        self._legend_loc = ctx.meta.get('legend_loc')
        # log10 of the smallest FPR value to display (used by ROC/DET)
        self._min_dig = None
        if 'min_far_value' in ctx.meta:
            self._min_dig = int(math.log10(ctx.meta['min_far_value']))
        elif self._axlim is not None and self._axlim[0] is not None:
            self._min_dig = int(math.log10(self._axlim[0])
                                if self._axlim[0] != 0 else 0)
        self._clayout = ctx.meta.get('clayout')
        self._far_at = ctx.meta.get('lines_at')
        # x positions of the vertical lines; subclasses may transform them
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            self._eval_points = {line: [] for line in self._far_at}
            self._lines_val = []
        self._print_fn = ctx.meta.get('show_fn', True)
        self._x_rotation = ctx.meta.get('x_rotation')
        if 'style' in ctx.meta:
            mpl.style.use(ctx.meta['style'])
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._line_linestyles = ctx.meta.get('line_linestyles', False)
        self._linestyles = utils.get_linestyles(
            self.n_systems, self._line_linestyles)
        # ``titles`` may be present in meta but set to None
        self._titles = (ctx.meta.get('titles') or []) * 2
        # for compatibility
        self._title = ctx.meta.get('title')
        if not self._titles and self._title is not None:
            self._titles = [self._title] * 2

        self._x_label = ctx.meta.get('x_label')
        self._y_label = ctx.meta.get('y_label')
        self._grid_color = 'silver'
        self._pdf_page = None
        self._end_setup_plot = True

    def init_process(self):
        ''' Select the pdf backend if needed, open the pdf output and
        create/clear one figure per output page '''
        if not hasattr(matplotlib, 'backends'):
            matplotlib.use('pdf')

        # reuse the PdfPages shared through the context, otherwise create one
        self._pdf_page = self._ctx.meta['PdfPages'] if 'PdfPages' in \
            self._ctx.meta else PdfPages(self._output)

        for i in range(self._nb_figs):
            fs = self._ctx.meta.get('figsize')
            fig = mpl.figure(i + 1, figsize=fs)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        ''' Set title, legend, axis labels, grid colors, save figures, draw
        lines and close pdf if needed '''
        # draw vertical lines
        if self._far_at is not None:
            for (line, line_trans) in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans], [-100.0, 100.], "--",
                    color='black'
                )
                if self._eval and self._split:
                    # join the eval. operating points, sorted along x
                    mpl.figure(2)
                    x_values = [i for i, _ in self._eval_points[line]]
                    y_values = [j for _, j in self._eval_points[line]]
                    sort_indice = sorted(
                        range(len(x_values)), key=x_values.__getitem__
                    )
                    x_values = [x_values[i] for i in sort_indice]
                    y_values = [y_values[i] for i in sort_indice]
                    mpl.plot(x_values,
                             y_values, '--',
                             color='black')
        # only for plots
        if self._end_setup_plot:
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = '' if not self._titles else self._titles[i]
                mpl.title(title if title.replace(' ', '') else '')
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                if self._disp_legend:
                    mpl.legend(loc=self._legend_loc)
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        # A PdfPages created by this object (not in ``ctx.meta``) is always
        # closed, otherwise the pdf is never finalized. A shared one is
        # closed unless the caller keeps it open with ``closef=False`` (do
        # not want to close PDF when running evaluate).
        if self._pdf_page is not None and (
                'PdfPages' not in self._ctx.meta or
                self._ctx.meta.get('closef', True)):
            self._pdf_page.close()

    # common protected functions
    def _label(self, base, idx):
        '''Return the legend of system ``idx`` if given, else ``base``
        (numbered when there are several systems).'''
        if self._legends is not None and len(self._legends) > idx:
            return self._legends[idx]
        if self.n_systems > 1:
            return base + (" %d" % (idx + 1))
        return base

    def _set_axis(self):
        '''Apply the user axis limits, if any.'''
        if self._axlim is not None:
            mpl.axis(self._axlim)
528

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
529

530
class Roc(PlotBase):
    ''' Handles the plotting of ROC'''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ['ROC dev.', 'ROC eval.']
        self._x_label = self._x_label or 'FPR'
        self._y_label = self._y_label or "1 - FNR"
        self._semilogx = ctx.meta.get('semilogx', True)
        # the default legend location depends on the x-axis scale
        if self._semilogx:
            default_loc = 'lower right'
        else:
            default_loc = 'upper right'
        self._legend_loc = self._legend_loc or default_loc
        # custom defaults
        if self._axlim is None:
            self._axlim = [None, None, -0.05, 1.05]

    def compute(self, idx, input_scores, input_names):
        ''' Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc`

        The dev. curve always goes to figure 1. With eval. data, the eval.
        curve is drawn on figure 2 when split, or dashed on figure 1
        otherwise, and the eval. operating points at the requested FPR
        lines (thresholds computed on dev.) are scattered.
        '''
        negatives, positives, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = negatives[0], positives[0]

        mpl.figure(1)
        LOGGER.info("ROC dev. curve using %s", input_names[0])
        plot.roc_for_far(
            dev_neg, dev_pos,
            far_values=plot.log_values(self._min_dig or -4),
            CAR=self._semilogx,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label('dev', idx)
        )
        if not self._eval:
            return

        eval_neg, eval_pos = negatives[1], positives[1]
        if self._split:
            mpl.figure(2)
        eval_style = self._linestyles[idx] if self._split else '--'
        LOGGER.info("ROC eval. curve using %s", input_names[1])
        plot.roc_for_far(
            eval_neg, eval_pos, linestyle=eval_style,
            far_values=plot.log_values(self._min_dig or -4),
            CAR=self._semilogx,
            color=self._colors[idx],
            label=self._label('eval.', idx)
        )
        if self._far_at is None:
            return

        from .. import farfrr
        for far_line in self._far_at:
            dev_threshold = far_threshold(dev_neg, dev_pos, far_line)
            fpr, fnr = farfrr(eval_neg, eval_pos, dev_threshold)
            # plotted on the "1 - FNR" axis
            tpr = 1 - fnr
            mpl.scatter(fpr, tpr, c=self._colors[idx], s=30)
            self._eval_points[far_line].append((fpr, tpr))

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
596

597 598
class Det(PlotBase):
    ''' Handles the plotting of DET '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ['DET dev.', 'DET eval.']
        self._x_label = self._x_label or 'FPR (%)'
        self._y_label = self._y_label or 'FNR (%)'
        self._legend_loc = self._legend_loc or 'upper right'
        if self._far_at is not None:
            # vertical lines are placed with ``ppndf`` like the eval. points
            self._trans_far_val = [ppndf(float(k)) for k in self._far_at]
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

        if self._axlim is None:
            self._axlim = [0.01, 99, 0.01, 99]
        else:
            # Work on a copy: ``axlim`` comes from ``ctx.meta`` and may be
            # shared with other plots (the assignment below would otherwise
            # leak into them) or be an immutable tuple.
            self._axlim = list(self._axlim)

        if self._min_dig is not None:
            self._axlim[0] = math.pow(10, self._min_dig) * 100

    def compute(self, idx, input_scores, input_names):
        ''' Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`

        The dev. curve always goes to figure 1. With eval. data, the eval.
        curve is drawn on figure 2 when split, or dashed on figure 1
        otherwise, and the eval. operating points at the requested FPR
        lines (thresholds computed on dev.) are scattered.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        eval_neg = eval_pos = None
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]

        mpl.figure(1)
        LOGGER.info("DET dev. curve using %s", input_names[0])
        plot.det(
            dev_neg, dev_pos, self._points, color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label('dev.', idx)
        )
        if eval_neg is None:
            return

        if self._split:
            mpl.figure(2)
        # dashed eval. curve when it shares the figure with the dev. one
        linestyle = '--' if not self._split else self._linestyles[idx]
        LOGGER.info("DET eval. curve using %s", input_names[1])
        plot.det(
            eval_neg, eval_pos, self._points, color=self._colors[idx],
            linestyle=linestyle,
            label=self._label('eval.', idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                # eval. error rates at the threshold ``far_threshold``
                # computes on the dev. scores for this FPR value
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(
                    eval_neg, eval_pos, thres_line)
                eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

    def _set_axis(self):
        '''Apply the axis limits with the DET-specific axis helper.'''
        plot.det_axis(self._axlim)
664

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
665

666 667
class Epc(PlotBase):
    ''' Handles the plotting of EPC

    EPC requires both a development and an evaluation score file per system
    (checked in ``__init__``); evaluation data is therefore always used.
    '''

    def __init__(self, ctx, scores, evaluation, func_load, hter='HTER'):
        '''
        Parameters
        ----------
        ctx, scores, evaluation, func_load :
            Forwarded unchanged to :py:class:`PlotBase`.
        hter : :py:class:`str`
            Metric name used for the default y label (``hter + ' (%)'``).

        Raises
        ------
        click.UsageError
            If the inputs are not dev. and eval. score files
            (``self._min_arg != 2``).
        '''
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev. and eval. score files")
        self._titles = self._titles or ['EPC'] * 2
        self._x_label = self._x_label or r'$\alpha$'
        self._y_label = self._y_label or hter + ' (%)'
        self._legend_loc = self._legend_loc or 'upper center'
        self._eval = True  # always eval data with EPC
        self._split = False
        self._nb_figs = 1
        self._far_at = None

    def compute(self, idx, input_scores, input_names):
        ''' Plot EPC using :py:func:`bob.measure.plot.epc`

        ``input_scores``/``input_names`` hold the dev. entry at index 0 and
        the eval. entry at index 1. Evaluation data is mandatory for EPC
        (``self._eval`` is forced to True in ``__init__``), so it is read
        unconditionally rather than leaving the ``eval_*`` names unbound on
        a path where ``self._eval`` would be False.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        eval_neg, eval_pos = neg_list[1], pos_list[1]
        dev_file, eval_file = input_names[0], input_names[1]

        # lazy %-formatting: same message, only built when INFO is enabled
        LOGGER.info("EPC using %s_%s", dev_file, eval_file)
        plot.epc(
            dev_neg, dev_pos, eval_neg, eval_pos, self._points,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label('curve', idx)
        )


class Hist(PlotBase):
702
    ''' Functional base class for histograms'''
Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
703

704
    def __init__(self, ctx, scores, evaluation, func_load, nhist_per_system=2):
        '''
        Parameters
        ----------
        ctx, scores, evaluation, func_load :
            Forwarded unchanged to :py:class:`PlotBase`.
        nhist_per_system : :py:class:`int`
            Number of histograms drawn per subplot (the default
            ``_setup_hist`` draws 2: negatives and positives). ``n_bins``
            must provide either 1 value or this many values.

        Raises
        ------
        click.BadParameter
            If ``n_bins`` / ``thres`` have an inconsistent number of values,
            or if ``hide_dev`` is requested without evaluation data.
        '''
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        self._nbins = ctx.meta.get('n_bins', ['doane'])
        self._nhist_per_system = nhist_per_system
        self._nbins = check_list_value(
            self._nbins, nhist_per_system, 'n_bins',
            'histograms')
        self._thres = ctx.meta.get('thres')
        self._thres = check_list_value(
            self._thres, self.n_systems, 'thresholds')
        self._criterion = ctx.meta.get('criterion')
        # no vertical (threshold) is displayed
        self._no_line = ctx.meta.get('no_line', False)
        # subplot grid
        self._nrows = ctx.meta.get('n_row', 1)
        self._ncols = ctx.meta.get('n_col', 1)
        # do not display dev histo
        self._hide_dev = ctx.meta.get('hide_dev', False)
        if self._hide_dev and not self._eval:
            raise click.BadParameter(
                "You can only use --hide-dev along with --eval")
        # when both are shown, each dev hist row is followed by a row with
        # the corresponding eval hists (see compute), hence twice the rows
        self._nrows *= 1 if self._hide_dev or not self._eval else 2
        self._nlegends = ctx.meta.get('legends_ncol', 3)
        self._legend_loc = self._legend_loc or 'upper center'
        # number of subplot on one page
        self._step_print = int(self._nrows * self._ncols)
        self._title_base = 'Scores'
        self._y_label = self._y_label or 'Probability density'
        self._x_label = self._x_label or 'Score values'
        self._end_setup_plot = False
        # override _titles of PlotBase; a 'titles' entry explicitly set to
        # None is treated like a missing one instead of raising TypeError
        titles = ctx.meta.get('titles')
        self._titles = (titles if titles is not None else []) * 2

    def compute(self, idx, input_scores, input_names):
        ''' Draw histograms of negative and positive scores.

        ``idx`` is the index of the system. When both dev. and eval.
        histograms are displayed, a system at grid position (r, c) gets its
        dev. histogram on subplot row ``2 * r`` and its eval. histogram on
        the row right below it.
        '''
        dev_neg, dev_pos, eval_neg, eval_pos, threshold = \
            self._get_neg_pos_thres(idx, input_scores, input_names)
        # keep id of the current system (named so it does not shadow the
        # ``sys`` module); ``idx`` is turned into the subplot index below
        system_idx = idx
        if not self._hide_dev and self._eval:
            row = (idx // self._ncols) * 2
            col = idx % self._ncols
            idx = col + self._ncols * row

        if not self._hide_dev or not self._eval:
            self._print_subplot(idx, system_idx, dev_neg, dev_pos, threshold,
                                not self._no_line, False)

        if self._eval:
            # eval hist goes one subplot row below its dev hist (if shown)
            if not self._hide_dev:
                idx += self._ncols
            self._print_subplot(idx, system_idx, eval_neg, eval_pos,
                                threshold, not self._no_line, True)

    def _print_subplot(self, idx, sys, neg, pos, threshold, draw_line, evaluation):
        ''' print a subplot for the given score and subplot index

        Parameters
        ----------
        idx : int
            Global subplot index; reduced modulo ``self._step_print`` to get
            the position on the current page.
        sys : int
            Index of the system whose scores are drawn.
        neg, pos :
            Negative / positive score lists forwarded to ``_setup_hist``.
        threshold :
            Position of the vertical threshold line.
        draw_line : bool
            Whether the threshold line is drawn.
        evaluation : bool
            True when drawing eval. scores.

        When the subplot is the last one of a page, or the last one to be
        displayed, the page is written to ``self._pdf_page``, the figure is
        closed and a new figure is opened for the next page.
        '''
        n = idx % self._step_print
        col = n % self._ncols
        sub_plot_idx = n + 1
        axis = mpl.subplot(self._nrows, self._ncols, sub_plot_idx)
        self._setup_hist(neg, pos)
        if col == 0:
            axis.set_ylabel(self._y_label)
        # systems per page; a system uses two subplots when both dev. and
        # eval. are shown (nrows was doubled in that case, so the division
        # is exact and the page arithmetic below stays in integers)
        sys_per_page = self._step_print // (
            1 if self._hide_dev or not self._eval else 2)
        # position of the system on its page, and number of systems left
        # to print counted from the start of the current page
        sys_idx = sys % sys_per_page
        rest_print = self.n_systems - (sys // sys_per_page) * sys_per_page
        # lower histo only
        is_lower = evaluation or not self._eval
        if is_lower and sys_idx + self._ncols >= min(sys_per_page, rest_print):
            axis.set_xlabel(self._x_label)
        dflt_title = "Eval. scores" if evaluation else "Dev. scores"
        if self.n_systems == 1 and (not self._eval or self._hide_dev):
            dflt_title = " "
        add = self.n_systems if is_lower else 0
        axis.set_title(self._get_title(sys + add, dflt_title))
        label = "%s threshold%s" % (
            '' if self._criterion is None else
            self._criterion.upper(), ' (dev)' if self._eval else ''
        )
        if draw_line:
            self._lines(threshold, label, neg, pos, idx)

        # if it was the last subplot of the page or the last subplot
        # to display, save figure
        if self._step_print == sub_plot_idx or (
                is_lower and sys == self.n_systems - 1):
            # print legend on the page
            self.plot_legends()
            mpl.tight_layout()
            fig = mpl.gcf()
            self._pdf_page.savefig(fig, bbox_inches='tight')
            # close the saved page: a figure that is only cleared stays
            # registered in pyplot, leaking one open figure per page
            mpl.close(fig)
            mpl.figure()

    def _get_title(self, idx, dflt=None):
        ''' Return the title of the histogram at position ``idx``.

        Uses the user title ``self._titles[idx]`` when available, otherwise
        ``dflt``; an empty result falls back to ``self._title_base``. A
        title made only of spaces (e.g. the ``" "`` placeholder) yields an
        empty string, and ``None`` is never returned.
        '''
        titles = self._titles
        has_custom = titles is not None and idx < len(titles)
        title = titles[idx] if has_custom else dflt
        if not title:
            title = self._title_base
        if title is not None and not title.replace(' ', ''):
            return ''
        return title or ''

    def plot_legends(self):
        ''' Print legend on current page

        Gathers the legend entries of every axis of the current figure,
        keeping only the first handle for each label, and draws them as one
        figure-level legend when ``self._disp_legend`` is set.
        '''
        if not self._disp_legend:
            return
        handles, names = [], []
        seen = set()
        for axes in mpl.gcf().get_axes():
            # avoid duplicates in legend
            for handle, name in zip(*axes.get_legend_handles_labels()):
                if name in seen:
                    continue
                seen.add(name)
                handles.append(handle)
                names.append(name)
        mpl.gcf().legend(
            handles, names, loc=self._legend_loc, fancybox=True,
            framealpha=0.5, ncol=self._nlegends,
            bbox_to_anchor=(0.55, 1.1),
        )

    def _get_neg_pos_thres(self, idx, input_scores, input_names):
        ''' Get scores and threshold for the given system at index idx

        Returns
        -------
        tuple
            ``(dev_neg, dev_pos, eval_neg, eval_pos, threshold)``. Each score
            item is a list with one entry per file of the system;
            ``eval_neg`` / ``eval_pos`` are None without evaluation data.
            ``threshold`` is ``self._thres[idx]`` when thresholds were
            given, otherwise ``utils.get_thres`` applied with
            ``self._criterion`` to the first dev. negatives/positives.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        # lists returned by get_fta_list contain the following items:
        # for bio or measure without eval:
        #   [dev]
        # for bio or measure with eval:
        #   [dev, eval]
        # for vuln with {licit,spoof} without eval:
        #   [licit_dev, spoof_dev]
        # for vuln with {licit,spoof} with eval:
        #   [licit_dev, licit_eval, spoof_dev, spoof_eval]
        # i.e. with eval data, dev and eval entries alternate
        step = 2 if self._eval else 1
        # can have several files for one system
        dev_neg = list(neg_list[0::step])
        dev_pos = list(pos_list[0::step])
        eval_neg = eval_pos = None
        if self._eval:
            eval_neg = list(neg_list[1::step])
            eval_pos = list(pos_list[1::step])

        if self._thres is None:
            threshold = utils.get_thres(
                self._criterion, dev_neg[0], dev_pos[0]
            )
        else:
            threshold = self._thres[idx]
        return dev_neg, dev_pos, eval_neg, eval_pos, threshold

    def _density_hist(self, scores, n, **kwargs):
        ''' Plot one density histogram of ``scores``.

        ``n`` selects the bin specification in ``self._nbins``; extra
        keyword arguments are passed to :py:func:`matplotlib.pyplot.hist`,
        whose ``(counts, bins, patches)`` result is returned as-is.
        '''
        bin_spec = self._nbins[n]
        counts, edges, patches = mpl.hist(
            scores, density=True, bins=bin_spec, **kwargs
        )
        return counts, edges, patches

    def _lines(self, threshold, label=None, neg=None, pos=None,
               idx=None, **kwargs):
        ''' Draw a vertical line at ``threshold`` on the current axis.

        Defaults to a dashed ``'C3'`` line labelled ``label`` (or
        ``'Threshold'`` when no label is given); any of these can be
        overridden through ``kwargs``. ``neg``, ``pos`` and ``idx`` are not
        used here (presumably kept for overriding subclasses — confirm).
        '''
        line_style = dict(color='C3', linestyle='--',
                          label=label or 'Threshold')
        line_style.update(kwargs)
        mpl.axvline(x=threshold, ymin=0, ymax=1, **line_style)

    def _setup_hist(self, neg, pos):
        ''' This function can be overwritten in derived classes

        Draws, in the current subplot, the density histogram of the first
        negative score set (bins ``self._nbins[0]``) and then of the first
        positive score set (bins ``self._nbins[1]``).
        '''
        series = (
            (neg, 'Negatives', 'C3'),
            (pos, 'Positives', 'C0'),
        )
        for bin_idx, (score_sets, name, colour) in enumerate(series):
            self._density_hist(
                score_sets[0], n=bin_idx,
                label=name, alpha=0.5, color=colour
            )