figure.py 35.3 KB
Newer Older
1 2 3 4
'''Runs error analysis on score sets, outputs metrics and plots'''

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
5
import math
6
import sys
7
import os.path
8
import numpy
9 10 11 12 13
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
14
from .. import (far_threshold, plot, utils, ppndf)
15 16 17
import logging

LOGGER = logging.getLogger("bob.measure")
18

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
19

20 21 22 23 24 25 26 27 28 29 30 31
def check_list_value(values, desired_number, name, name2='systems'):
    """Broadcast a one-item option to ``desired_number`` entries.

    Parameters
    ----------
    values : sequence or None
        Option values to check.
    desired_number : int
        Number of entries the result must hold.
    name : str
        Name of the option, only used in the error message.
    name2 : str
        Name of what ``desired_number`` counts, only used in the error
        message.

    Returns
    -------
    ``values`` itself when it is ``None`` or already holds
    ``desired_number`` items, otherwise the single item repeated
    ``desired_number`` times.

    Raises
    ------
    click.BadParameter
        If ``values`` holds neither one nor ``desired_number`` items.
    """
    # nothing to adapt: option not given or already of the right size
    if values is None or len(values) == desired_number:
        return values

    if len(values) != 1:
        raise click.BadParameter(
            '#{} ({}) must be either 1 value or the same as '
            '#{} ({} values)'.format(name, values, name2, desired_number))

    return values * desired_number


32 33 34 35 36 37 38 39 40 41
class MeasureBase(object):
    """Base class for metrics and plots.

    This abstract class defines the framework to plot or compute metrics from
    a list of (positive, negative) scores tuples.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    n_systems : int
        Number of systems, i.e. the number of input files divided by the
        number of files expected per system (``ctx.meta['min_arg']``)
    """
    # NOTE: Python 3 ignores ``__metaclass__``, so the ``@abstractmethod``
    # markers below are only enforced when running under Python 2.
    __metaclass__ = ABCMeta  # for python 2.7 compatibility

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx :
            Click context; its ``meta`` dictionary is read for the
            ``legends`` and ``min_arg`` entries.

        scores : :any:`list`:
            List of input files (e.g. dev-{1, 2, 3}, {dev,eval}-scores1
            {dev,eval}-scores2)
        evaluation : :py:class:`bool`
            True if eval data are used
        func_load : Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If ``scores`` is empty, if its length is not a multiple of
            ``min_arg``, or if fewer legends than systems are given.
        """
        self._scores = scores
        self._ctx = ctx
        self.func_load = func_load
        self._legends = ctx.meta.get('legends')
        self._eval = evaluation
        # number of input files that make up one system
        self._min_arg = ctx.meta.get('min_arg', 1)
        if len(scores) < 1 or len(scores) % self._min_arg != 0:
            raise click.BadParameter(
                'Number of argument must be a non-zero multiple of %d' % self._min_arg
            )
        # exact because the length was just checked to be a multiple
        self.n_systems = len(scores) // self._min_arg
        if self._legends is not None and len(self._legends) < self.n_systems:
            raise click.BadParameter("Number of legends must be >= to the "
                                     "number of systems")

    def run(self):
        """ Generate outputs (e.g. metrics, files, pdf plots).
        This function calls abstract methods
        :func:`~bob.measure.script.figure.MeasureBase.init_process` (before
        loop), :py:func:`~bob.measure.script.figure.MeasureBase.compute`
        (in the loop iterating through the different
        systems) and :py:func:`~bob.measure.script.figure.MeasureBase.end_process`
        (after the loop).
        """
        # init matplotlib, log files, ...
        self.init_process()
        # iterates through the different systems and feed `compute`
        # with the dev (and eval) scores of each system
        # Note that more than one dev or eval scores score can be passed to
        # each system
        for idx in range(self.n_systems):
            # Scores are given as followed:
            # SysA-dev SysA-eval ... SysA-XX  SysB-dev SysB-eval ... SysB-XX
            # ------------------------------  ------------------------------
            #   First set of `self._min_arg`     Second set of input files
            #     input files starting at               for SysB
            #    index idx * self._min_arg
            first = idx * self._min_arg
            system_files = self._scores[first:first + self._min_arg]
            # load scores for each system: get the corresponding arrays and
            # base-name of files
            input_scores, input_names = self._load_files(system_files)
            self.compute(idx, input_scores, input_names)
        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        before iterating through the different systems.
        Should be reimplemented in derived classes"""

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots from the given scores provided by
        :py:func:`~bob.measure.script.figure.MeasureBase.run`.
        Should be reimplemented in derived classes

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        """
        # structure of input is (vuln example):
        # if evaluation is provided
        # [ (dev_licit_neg, dev_licit_pos), (eval_licit_neg, eval_licit_pos),
        #   (dev_spoof_neg, dev_licit_pos), (eval_spoof_neg, eval_licit_pos)]
        # and if only dev:
        # [ (dev_licit_neg, dev_licit_pos), (dev_spoof_neg, dev_licit_pos)]

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        after iterating through the different systems.
        Should be reimplemented in derived classes"""

    # common protected functions

    def _load_files(self, filepaths):
        ''' Load the input files and return the base names of the files

        Parameters
        ----------
        filepaths : :any:`list`
            Paths of the files to load with ``func_load``

        Returns
        -------
            scores: :any:`list`:
                A list that contains the output of
                ``func_load`` for the given files
            basenames: :any:`list`:
                A list of basenames for the given files (file name up to
                its first ``.``)
        '''
        scores = []
        basenames = []
        for filename in filepaths:
            basenames.append(os.path.basename(filename).split(".")[0])
            scores.append(self.func_load(filename))
        return scores, basenames
162

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
163

164 165 166 167 168 169 170 171
class Metrics(MeasureBase):
    ''' Compute metrics from score files

    Attributes
    ----------
    log_file:
        output stream; ``sys.stdout`` unless ``ctx.meta['log']`` names a file,
        in which case that file is opened with ``ctx.meta['open_mode']``
    names : tuple
        Labels of the table rows, in the order false positive rate, false
        negative rate, precision, recall, F1-score
    '''

    def __init__(self, ctx, scores, evaluation, func_load,
                 names=('False Positive Rate', 'False Negative Rate',
                        'Precision', 'Recall', 'F1-score')):
        '''
        Parameters
        ----------
        ctx, scores, evaluation, func_load :
            Forwarded to the base class constructor.
        names : tuple
            Labels of the five table rows. :py:meth:`compute` pairs them, in
            order, with FPR, FNR, precision, recall and F1-score.

        Raises
        ------
        click.BadParameter
            If the number of user thresholds is neither 1 nor the number of
            systems.
        '''
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        self.names = names
        self._tablefmt = ctx.meta.get('tablefmt')
        self._criterion = ctx.meta.get('criterion')
        self._open_mode = ctx.meta.get('open_mode')
        self._thres = ctx.meta.get('thres')
        self._decimal = ctx.meta.get('decimal', 2)
        if self._thres is not None:
            if len(self._thres) == 1:
                # a single threshold applies to every system
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                # n_systems is an int: format it directly (len() of it would
                # raise TypeError and hide the intended error message)
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._far = ctx.meta.get('far_value')
        self._log = ctx.meta.get('log')
        self.log_file = sys.stdout
        if self._log is not None:
            # closed in end_process()
            self.log_file = open(self._log, self._open_mode)

    def get_thres(self, criterion, dev_neg, dev_pos, far):
        '''Return the threshold given by :py:func:`utils.get_thres`'''
        return utils.get_thres(criterion, dev_neg, dev_pos, far)

    def _numbers(self, neg, pos, threshold, fta):
        '''Compute the raw metrics of one score set at ``threshold``

        Returns
        -------
        tuple
            ``(fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc, precision,
            recall, f1_score)``
        '''
        from .. import (farfrr, precision_recall, f_score)
        # fpr and fnr
        fmr, fnmr = farfrr(neg, pos, threshold)
        hter = (fmr + fnmr) / 2.0
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)

        ni = neg.shape[0]  # number of impostors
        fm = int(round(fmr * ni))  # number of false accepts
        nc = pos.shape[0]  # number of clients
        fnm = int(round(fnmr * nc))  # number of false rejects

        # precision and recall
        precision, recall = precision_recall(neg, pos, threshold)

        # f_score
        f1_score = f_score(neg, pos, threshold, 1)
        return (fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc, precision,
                recall, f1_score)

    def _strings(self, metrics):
        '''Format the tuple returned by :py:meth:`_numbers` for display

        Returns
        -------
        tuple
            ``(fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str,
            prec_str, recall_str, f1_str)``
        '''
        spec = '.%df' % self._decimal

        def _percent(value):
            return "%s%%" % format(100 * value, spec)

        fta_str = _percent(metrics[0])
        fmr_str = "%s (%d/%d)" % (_percent(metrics[1]), metrics[6], metrics[7])
        fnmr_str = "%s (%d/%d)" % (_percent(metrics[2]), metrics[8],
                                   metrics[9])
        far_str = _percent(metrics[4])
        frr_str = _percent(metrics[5])
        hter_str = _percent(metrics[3])
        prec_str = format(metrics[10], spec)
        recall_str = format(metrics[11], spec)
        f1_str = format(metrics[12], spec)

        return (fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str,
                prec_str, recall_str, f1_str)

    def _get_all_metrics(self, idx, input_scores, input_names):
        ''' Compute all metrics for dev and eval scores

        Returns
        -------
        list
            ``[dev_strings, eval_strings]`` as produced by
            :py:meth:`_strings`; ``eval_strings`` is ``None`` when no eval
            data are used.
        '''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]

        # a user-provided threshold takes precedence over the criterion
        if self._thres is None:
            threshold = self.get_thres(
                self._criterion, dev_neg, dev_pos, self._far)
        else:
            threshold = self._thres[idx]

        title = self._legends[idx] if self._legends is not None else None
        if self._thres is None:
            far_str = ''
            if self._criterion == 'far' and self._far is not None:
                far_str = str(self._far)
            click.echo("[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                       % (self._criterion.upper(),
                          far_str, title or dev_file,
                          threshold),
                       file=self.log_file)
        else:
            click.echo("[Min. criterion: user provided] Threshold on "
                       "Development set `%s`: %e"
                       % (dev_file or title, threshold), file=self.log_file)

        res = []
        res.append(self._strings(self._numbers(
            dev_neg, dev_pos, threshold, dev_fta)))

        if self._eval:
            # computes statistics for the eval set based on the threshold a
            # priori
            res.append(self._strings(self._numbers(
                eval_neg, eval_pos, threshold, eval_fta)))
        else:
            res.append(None)

        return res

    def _warn_on_nans(self, fta_str, filename):
        '''Report on stdout and in the logs when ``filename`` held NaN scores'''
        if float(fta_str.replace('%', '')) > 0.0:
            click.echo("NaNs scores (%s) were found in %s" %
                       (fta_str, filename))
            LOGGER.warning("NaNs scores (%s) were found in %s", fta_str,
                           filename)

    def compute(self, idx, input_scores, input_names):
        ''' Compute metrics thresholds and tables (FPR, FNR, precision, recall,
        f1_score) for given system inputs'''
        dev_file = input_names[0]
        title = self._legends[idx] if self._legends is not None else None
        all_metrics = self._get_all_metrics(idx, input_scores, input_names)
        self._warn_on_nans(all_metrics[0][0], dev_file)

        # positions of FPR, FNR, precision, recall and F1 in _strings() output
        columns = (1, 2, 6, 7, 8)
        # use the legend as header of the first column when there is one
        headers = [title or ' ', 'Development']
        rows = [[self.names[row], all_metrics[0][col]]
                for row, col in enumerate(columns)]

        if self._eval:
            eval_file = input_names[1]
            self._warn_on_nans(all_metrics[1][0], eval_file)
            # computes statistics for the eval set based on the threshold a
            # priori
            headers.append('Evaluation')
            for row, col in enumerate(columns):
                rows[row].append(all_metrics[1][col])

        click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)

    def end_process(self):
        ''' Close log file if needed'''
        if self._log is not None:
            self.log_file.close()

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
320

321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405
class MultiMetrics(Metrics):
    '''Computes average of metrics based on several protocols (cross
    validation)

    Attributes
    ----------
    log_file :
        output stream
    names : tuple
        List of names for the metrics.
    headers : list
        Column headers of the final table
    rows : list
        One table row per system, filled by :py:meth:`compute`
    '''

    def __init__(self, ctx, scores, evaluation, func_load,
                 names=('NaNs Rate', 'False Positive Rate',
                        'False Negative Rate', 'False Accept Rate',
                        'False Reject Rate', 'Half Total Error Rate')):
        super(MultiMetrics, self).__init__(
            ctx, scores, evaluation, func_load, names=names)

        self.headers = ['Methods'] + list(self.names)
        if self._eval:
            # with eval data only the dev HTER is shown, ahead of the eval
            # columns
            self.headers.insert(1, self.names[5] + ' (dev)')
        self.rows = []

    def _strings(self, metrics):
        '''Format mean and standard deviation of the metrics over protocols

        Parameters
        ----------
        metrics : numpy.ndarray
            One row per protocol, each row being the tuple returned by
            ``_numbers``.

        Returns
        -------
        tuple
            ``(fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str)``,
            each formatted as ``mean% (std%)``
        '''
        # Only the six leading columns (fta, fmr, fnmr, hter, far, frr) are
        # rates worth averaging; slicing keeps this independent of how many
        # extra values ``_numbers`` returns.
        ftam, fmrm, fnmrm, hterm, farm, frrm = metrics.mean(axis=0)[:6]
        ftas, fmrs, fnmrs, hters, fars, frrs = metrics.std(axis=0)[:6]

        def _mean_std(mean, std):
            return "%.1f%% (%.1f%%)" % (100 * mean, 100 * std)

        fta_str = _mean_std(ftam, ftas)
        fmr_str = _mean_std(fmrm, fmrs)
        fnmr_str = _mean_std(fnmrm, fnmrs)
        far_str = _mean_std(farm, fars)
        frr_str = _mean_std(frrm, frrs)
        hter_str = _mean_std(hterm, hters)

        return fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str

    def compute(self, idx, input_scores, input_names):
        '''Computes the average of metrics over several protocols.'''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        # with eval data the inputs alternate dev, eval, dev, eval, ...
        step = 2 if self._eval else 1
        self._dev_metrics = []
        self._thresholds = []
        for i in range(0, len(input_scores), step):
            neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
            if self._thres is None:
                threshold = self.get_thres(
                    self._criterion, neg, pos, self._far)
            else:
                threshold = self._thres[idx]
            self._thresholds.append(threshold)
            self._dev_metrics.append(self._numbers(neg, pos, threshold, fta))
        self._dev_metrics = numpy.array(self._dev_metrics)

        if self._eval:
            self._eval_metrics = []
            for i in range(1, len(input_scores), step):
                neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
                # reuse the threshold of the dev set of the same protocol
                threshold = self._thresholds[i // 2]
                self._eval_metrics.append(
                    self._numbers(neg, pos, threshold, fta))
            self._eval_metrics = numpy.array(self._eval_metrics)

        title = self._legends[idx] if self._legends is not None else None

        fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str = \
            self._strings(self._dev_metrics)

        if self._eval:
            self.rows.append([title, hter_str])
        else:
            self.rows.append([title, fta_str, fmr_str, fnmr_str,
                              far_str, frr_str, hter_str])

        if self._eval:
            # computes statistics for the eval set based on the threshold a
            # priori
            fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str = \
                self._strings(self._eval_metrics)

            self.rows[-1].extend([fta_str, fmr_str, fnmr_str,
                                  far_str, frr_str, hter_str])

    def end_process(self):
        '''Print the table of all systems, then close the log file if needed'''
        click.echo(tabulate(self.rows, self.headers,
                            self._tablefmt), file=self.log_file)
        super(MultiMetrics, self).end_process()


406 407 408 409
class PlotBase(MeasureBase):
    ''' Base class for plots. Regroup several options and code
    shared by the different plots
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        '''Read the plotting options from ``ctx.meta``; the remaining
        arguments are forwarded to the base class constructor.'''
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        self._output = ctx.meta.get('output')
        self._points = ctx.meta.get('points', 100)
        self._split = ctx.meta.get('split')
        self._axlim = ctx.meta.get('axlim')
        self._disp_legend = ctx.meta.get('disp_legend', True)
        self._legend_loc = ctx.meta.get('legend_loc')
        # exponent (base 10) of the smallest FAR value, when one can be
        # derived from the options
        self._min_dig = None
        if 'min_far_value' in ctx.meta:
            self._min_dig = int(math.log10(ctx.meta['min_far_value']))
        elif self._axlim is not None and self._axlim[0] is not None:
            self._min_dig = int(math.log10(self._axlim[0])
                                if self._axlim[0] != 0 else 0)
        self._clayout = ctx.meta.get('clayout')
        self._far_at = ctx.meta.get('lines_at')
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            self._eval_points = {line: [] for line in self._far_at}
            self._lines_val = []
        self._print_fn = ctx.meta.get('show_fn', True)
        self._x_rotation = ctx.meta.get('x_rotation')
        if 'style' in ctx.meta:
            mpl.style.use(ctx.meta['style'])
        # dev and eval curves go on separate figures only when split
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._line_linestyles = ctx.meta.get('line_linestyles', False)
        self._linestyles = utils.get_linestyles(
            self.n_systems, self._line_linestyles)
        # doubled so that both figures can be indexed; an explicit ``None``
        # entry is treated like a missing one
        self._titles = (ctx.meta.get('titles') or []) * 2
        # for compatibility
        self._title = ctx.meta.get('title')
        if not self._titles and self._title is not None:
            self._titles = [self._title] * 2

        self._x_label = ctx.meta.get('x_label')
        self._y_label = ctx.meta.get('y_label')
        self._grid_color = 'silver'
        self._pdf_page = None
        self._end_setup_plot = True

    def init_process(self):
        ''' Open pdf and set axis font size if provided '''
        if not hasattr(matplotlib, 'backends'):
            matplotlib.use('pdf')

        # reuse the PdfPages shared through the context when there is one
        if 'PdfPages' in self._ctx.meta:
            self._pdf_page = self._ctx.meta['PdfPages']
        else:
            self._pdf_page = PdfPages(self._output)

        for i in range(self._nb_figs):
            fs = self._ctx.meta.get('figsize')
            fig = mpl.figure(i + 1, figsize=fs)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        ''' Set title, legend, axis labels, grid colors, save figures, draw
        lines and close pdf if needed '''
        # draw vertical lines
        if self._far_at is not None:
            for (line, line_trans) in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans], [-100.0, 100.], "--",
                    color='black'
                )
                if self._eval and self._split:
                    mpl.figure(2)
                    # join the eval points of all systems, ordered by x
                    x_values = [i for i, _ in self._eval_points[line]]
                    y_values = [j for _, j in self._eval_points[line]]
                    sort_indice = sorted(
                        range(len(x_values)), key=x_values.__getitem__
                    )
                    x_values = [x_values[i] for i in sort_indice]
                    y_values = [y_values[i] for i in sort_indice]
                    mpl.plot(x_values,
                             y_values, '--',
                             color='black')
        # only for plots
        if self._end_setup_plot:
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = '' if not self._titles else self._titles[i]
                # a title made only of spaces means "no title"
                mpl.title(title if title.replace(' ', '') else '')
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                if self._disp_legend:
                    mpl.legend(loc=self._legend_loc)
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        # do not want to close PDF when running evaluate: a PdfPages shared
        # through the context is only closed when 'closef' allows it. One
        # created by init_process() is not referenced by the context, so it
        # is closed here unless 'closef' says otherwise; left open, the
        # output file would never be finalized.
        if self._pdf_page is not None and self._ctx.meta.get('closef', True):
            self._pdf_page.close()

    # common protected functions

    def _label(self, base, name, idx):
        '''Return the legend of system ``idx``: the user legend when one
        exists, otherwise a label built from ``base`` and ``name``'''
        if self._legends is not None and len(self._legends) > idx:
            return self._legends[idx]
        if self.n_systems > 1:
            return base + (" %d (%s)" % (idx + 1, name))
        return base + (" (%s)" % name)

    def _set_axis(self):
        '''Apply the axis limits to the current figure, if any are set'''
        if self._axlim is not None:
            mpl.axis(self._axlim)
521

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
522

523
class Roc(PlotBase):
    ''' Handles the plotting of ROC'''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ['ROC dev.', 'ROC eval.']
        self._x_label = self._x_label or 'FPR'
        self._y_label = self._y_label or "1 - FNR"
        self._semilogx = ctx.meta.get('semilogx', True)
        best_legend = 'lower right' if self._semilogx else 'upper right'
        self._legend_loc = self._legend_loc or best_legend
        # custom defaults
        if self._axlim is None:
            self._axlim = [None, None, -0.05, 1.05]

    def compute(self, idx, input_scores, input_names):
        ''' Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc_for_far`'''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        # the dev curve is always drawn on the first figure
        mpl.figure(1)
        plot.roc_for_far(
            dev_neg, dev_pos,
            far_values=plot.log_values(self._min_dig or -4),
            CAR=self._semilogx,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label('dev', dev_file, idx)
        )
        if not self._eval:
            return

        if self._split:
            mpl.figure(2)

        # on a shared figure the eval curve is dashed to tell it apart from
        # the dev one
        linestyle = '--' if not self._split else self._linestyles[idx]
        plot.roc_for_far(
            eval_neg, eval_pos, linestyle=linestyle,
            far_values=plot.log_values(self._min_dig or -4),
            CAR=self._semilogx,
            color=self._colors[idx],
            label=self._label('eval.', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                # threshold chosen on dev, applied to eval
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(
                    eval_neg, eval_pos, thres_line)
                eval_fnmr = 1 - eval_fnmr
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
586

587 588
class Det(PlotBase):
    ''' Handles the plotting of DET '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ['DET dev.', 'DET eval.']
        self._x_label = self._x_label or 'FPR (%)'
        self._y_label = self._y_label or 'FNR (%)'
        self._legend_loc = self._legend_loc or 'upper right'
        if self._far_at is not None:
            # the vertical lines must be given in the transformed (ppndf)
            # scale
            self._trans_far_val = [ppndf(float(k)) for k in self._far_at]
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

        if self._axlim is None:
            self._axlim = [0.01, 99, 0.01, 99]

        if self._min_dig is not None:
            # work on a private list: the limits may be the very object
            # stored in ``ctx.meta`` (shared with other commands) or an
            # immutable tuple
            self._axlim = list(self._axlim)
            self._axlim[0] = math.pow(10, self._min_dig) * 100

    def compute(self, idx, input_scores, input_names):
        ''' Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`'''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        # the dev curve is always drawn on the first figure
        mpl.figure(1)
        plot.det(
            dev_neg, dev_pos, self._points, color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label('dev.', dev_file, idx)
        )
        if not (self._eval and eval_neg is not None):
            return

        if self._split:
            mpl.figure(2)
        # on a shared figure the eval curve is dashed to tell it apart from
        # the dev one
        linestyle = '--' if not self._split else self._linestyles[idx]
        plot.det(
            eval_neg, eval_pos, self._points, color=self._colors[idx],
            linestyle=linestyle,
            label=self._label('eval.', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                # threshold chosen on dev, applied to eval
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(
                    eval_neg, eval_pos, thres_line)
                eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

    def _set_axis(self):
        '''Apply the axis limits with :py:func:`bob.measure.plot.det_axis`'''
        plot.det_axis(self._axlim)
651

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
652

653 654
class Epc(PlotBase):
    ''' Handles the plotting of EPC '''

    def __init__(self, ctx, scores, evaluation, func_load, hter='HTER'):
        '''
        Parameters
        ----------
        ctx, scores, evaluation, func_load :
            Forwarded to the base class constructor.
        hter : str
            Name of the error rate, used for the default y label.

        Raises
        ------
        click.UsageError
            If systems are not made of exactly two score files.
        '''
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev. and eval. score files")
        self._titles = self._titles or ['EPC'] * 2
        self._x_label = self._x_label or r'$\alpha$'
        self._y_label = self._y_label or hter + ' (%)'
        self._legend_loc = self._legend_loc or 'upper center'
        self._eval = True  # always eval data with EPC
        self._split = False
        self._nb_figs = 1
        self._far_at = None

    def compute(self, idx, input_scores, input_names):
        ''' Plot EPC using :py:func:`bob.measure.plot.epc` '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        # the constructor guarantees two files per system and forces
        # ``_eval`` to True, so the eval set is always the second entry
        eval_neg, eval_pos = neg_list[1], pos_list[1]
        eval_file = input_names[1]

        plot.epc(
            dev_neg, dev_pos, eval_neg, eval_pos, self._points,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label(
                'curve', dev_file + "_" + eval_file, idx
            )
        )

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
686

687
class Hist(PlotBase):
688
    ''' Functional base class for histograms'''
Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
689

690
    def __init__(self, ctx, scores, evaluation, func_load, nhist_per_system=2):
        ''' Read the histogram options stored in ``ctx.meta``.

        Parameters
        ----------
        ctx :
            Click context; every option is looked up in ``ctx.meta``.
        scores, evaluation, func_load :
            Forwarded unchanged to the parent constructor.
        nhist_per_system : :obj:`int`
            Number of histograms drawn for one system. ``n_bins`` must hold
            either a single value or that many values.

        Raises
        ------
        click.BadParameter
            If ``hide_dev`` is requested while no eval data are used, or
            (raised by :py:func:`check_list_value`) if ``n_bins`` or
            ``thres`` hold an unsupported number of values.
        '''
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        self._nhist_per_system = nhist_per_system
        self._nbins = check_list_value(
            ctx.meta.get('n_bins', ['doane']), nhist_per_system, 'n_bins',
            'histograms')
        self._thres = check_list_value(
            ctx.meta.get('thres'), self.n_systems, 'thresholds')
        self._criterion = ctx.meta.get('criterion')
        # no vertical (threshold) is displayed
        self._no_line = ctx.meta.get('no_line', False)
        # subplot grid
        self._nrows = ctx.meta.get('n_row', 1)
        self._ncols = ctx.meta.get('n_col', 1)
        # do not display dev histo
        self._hide_dev = ctx.meta.get('hide_dev', False)
        if self._hide_dev and not self._eval:
            raise click.BadParameter(
                "You can only use --hide-dev along with --eval")
        # dev hist are displayed next to eval hist: when both are drawn every
        # system needs two rows of the grid
        self._nrows *= 1 if self._hide_dev or not self._eval else 2
        self._nlegends = ctx.meta.get('legends_ncol', 3)
        self._legend_loc = self._legend_loc or 'upper center'
        # number of subplot on one page
        self._step_print = int(self._nrows * self._ncols)
        self._title_base = 'Scores'
        self._y_label = self._y_label or 'Probability density'
        self._x_label = self._x_label or 'Score values'
        self._end_setup_plot = False
        # override _titles of PlotBase. The key may be present with a None
        # value, which must not be multiplied; fall back to an empty list.
        self._titles = (ctx.meta.get('titles') or []) * 2
723

724
    def compute(self, idx, input_scores, input_names):
        ''' Draw the histograms of negative and positive scores of one system.

        The dev histogram is drawn unless it is hidden while eval data are
        used; the eval histogram is drawn whenever eval data are used. When
        both are drawn, the eval subplot is placed one grid row below the
        dev subplot.
        '''
        dev_neg, dev_pos, eval_neg, eval_pos, threshold = \
            self._get_neg_pos_thres(idx, input_scores, input_names)

        show_dev = not self._hide_dev
        with_line = not self._no_line

        # ``idx`` keeps identifying the system, ``slot`` is the subplot index
        slot = idx
        if show_dev and self._eval:
            # every system takes two rows: skip the rows holding eval plots
            ncols = self._ncols
            slot = (idx % ncols) + ncols * (int(idx / ncols) * 2)

        if show_dev or not self._eval:
            self._print_subplot(slot, idx, dev_neg, dev_pos, threshold,
                                with_line, False)

        if self._eval:
            if show_dev:
                # same column, next row
                slot += self._ncols
            self._print_subplot(slot, idx, eval_neg, eval_pos, threshold,
                                with_line, True)
745

746
    def _print_subplot(self, idx, sys, neg, pos, threshold, draw_line, evaluation):
        ''' Print a subplot for the given scores and subplot index.

        Parameters
        ----------
        idx :
            Subplot index; its position on the current page is
            ``idx % self._step_print``.
        sys :
            Index of the system being drawn.
            NOTE(review): this name shadows the ``sys`` module imported at
            the top of the file for the duration of the method.
        neg, pos :
            Negative and positive scores, handed to ``_setup_hist`` and
            ``_lines``.
        threshold :
            Threshold handed to ``_lines``.
        draw_line :
            When true, ``_lines`` is called to draw the threshold.
        evaluation :
            When true, the default title is "Eval. scores" instead of
            "Dev. scores".
        '''
        # position of the subplot on the current page (0-based) and its column
        n = idx % self._step_print
        col = n % self._ncols
        # matplotlib numbers subplots from 1
        sub_plot_idx = n + 1
        axis = mpl.subplot(self._nrows, self._ncols, sub_plot_idx)
        self._setup_hist(neg, pos)
        # y label only on the first column
        if col == 0:
            axis.set_ylabel(self._y_label)
        # systems per page (a float: the file enables true division)
        sys_per_page = self._step_print / (1 if self._hide_dev or not
                                           self._eval else 2)
        # rest to be printed
        sys_idx = sys % sys_per_page
        rest_print = self.n_systems - int(sys / sys_per_page) * sys_per_page
        # lower histo only
        is_lower = evaluation or not self._eval
        # x label only when no other lower histogram follows in this column
        if is_lower and sys_idx + self._ncols >= min(sys_per_page, rest_print):
            axis.set_xlabel(self._x_label)
        dflt_title = "Eval. scores" if evaluation else "Dev. scores"
        # a blank default makes _get_title return an empty title
        if self.n_systems == 1 and (not self._eval or self._hide_dev):
            dflt_title = " "
        # titles of the lower histograms are looked up n_systems further
        add = self.n_systems if is_lower else 0
        axis.set_title(self._get_title(sys + add, dflt_title))
        # legend label of the threshold line: upper-cased criterion name (if
        # any), with a ' (dev)' suffix when eval data are used
        label = "%s threshold%s" % (
            '' if self._criterion is None else
            self._criterion.upper(), ' (dev)' if self._eval else ''
        )
        if draw_line:
            self._lines(threshold, label, neg, pos, idx)

        # if it was the last subplot of the page or the last subplot
        # to display, save figure
        if self._step_print == sub_plot_idx or (is_lower and sys ==
                                                self.n_systems - 1):
            # print legend on the page
            self.plot_legends()
            mpl.tight_layout()
            self._pdf_page.savefig(mpl.gcf(), bbox_inches='tight')
            # clear the saved figure, then open a new one for the next page
            mpl.clf()
            mpl.figure()
788

789
    def _get_title(self, idx, dflt=None):
790
        ''' Get the histo title for the given idx'''
791 792
        title = self._titles[idx] if self._titles is not None \
            and idx < len(self._titles) else dflt
793
        title = title or self._title_base
794 795
        title = '' if title is not None and not title.replace(
            ' ', '') else title
796
        return title or ''
797

798 799
    def plot_legends(self):
        ''' Print one legend for the whole current page.

        The legend entries of all the axes of the current figure are
        gathered, each label being kept once, and drawn as a single
        figure-level legend when ``self._disp_legend`` is set.
        '''
        handles, names = [], []
        for axes in mpl.gcf().get_axes():
            axes_handles, axes_names = axes.get_legend_handles_labels()
            for handle, name in zip(axes_handles, axes_names):
                # the same label shows up on every subplot: list it once
                if name in names:
                    continue
                handles.append(handle)
                names.append(name)

        if not self._disp_legend:
            return
        mpl.gcf().legend(
            handles, names, loc=self._legend_loc, fancybox=True,
            framealpha=0.5, ncol=self._nlegends,
            bbox_to_anchor=(0.55, 1.1),
        )
816

817
    def _get_neg_pos_thres(self, idx, input_scores, input_names):
        ''' Get the scores and the threshold of the system at index idx.

        Returns
        -------
        tuple
            ``(dev_neg, dev_pos, eval_neg, eval_pos, threshold)``. The first
            four items are lists with one entry per file of the system;
            ``eval_neg`` and ``eval_pos`` are ``None`` when no eval data are
            used.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        # The lists returned by get_fta_list are laid out as follows:
        #   bio or measure, without eval: [dev]
        #   bio or measure, with eval:    [dev, eval]
        #   vuln, without eval:           [licit_dev, spoof_dev]
        #   vuln, with eval:
        #       [licit_dev, licit_eval, spoof_dev, spoof_eval]
        # i.e. with eval data, dev and eval entries alternate.
        stride = 2 if self._eval else 1
        total = len(neg_list)

        # can have several files for one system
        dev_ids = range(0, total, stride)
        dev_neg = [neg_list[i] for i in dev_ids]
        dev_pos = [pos_list[i] for i in dev_ids]

        if self._eval:
            eval_ids = range(1, total, stride)
            eval_neg = [neg_list[i] for i in eval_ids]
            eval_pos = [pos_list[i] for i in eval_ids]
        else:
            eval_neg = eval_pos = None

        if self._thres is None:
            # no threshold given: compute it on the first dev set
            threshold = utils.get_thres(
                self._criterion, dev_neg[0], dev_pos[0])
        else:
            threshold = self._thres[idx]
        return dev_neg, dev_pos, eval_neg, eval_pos, threshold
843

844
    def _density_hist(self, scores, n, **kwargs):
        ''' Plot one density histogram of ``scores``.

        ``n`` selects which entry of ``self._nbins`` is used as ``bins``;
        the extra keyword arguments are given to ``mpl.hist``, whose three
        return values are returned as a tuple.
        '''
        counts, edges, patches = mpl.hist(
            scores, bins=self._nbins[n], density=True, **kwargs)
        return counts, edges, patches

853 854
    def _lines(self, threshold, label=None, neg=None, pos=None,
               idx=None, **kwargs):
        ''' Plot a vertical line at ``threshold``.

        ``neg``, ``pos`` and ``idx`` are not used here. Keyword arguments
        given by the caller take precedence over the default colour, line
        style and label.
        '''
        style = {
            'color': 'C3',
            'linestyle': '--',
            'label': label or 'Threshold',
        }
        style.update(kwargs)
        # the line spans the full height of the axes
        mpl.axvline(x=threshold, ymin=0, ymax=1, **style)
862 863

    def _setup_hist(self, neg, pos):
864 865 866 867 868
        ''' This function can be overwritten in derived classes

        Plots all the density histo required in one plot. Here negative and
        positive scores densities.
        '''
869
        self._density_hist(
870 871
            neg[0], n=0,
            label='Negatives', alpha=0.5, color='C3'
872 873
        )
        self._density_hist(
874 875
            pos[0], n=1,
            label='Positives', alpha=0.5, color='C0'
876
        )