figure.py 35 KB
Newer Older
1 2 3 4
'''Runs error analysis on score sets, outputs metrics and plots'''

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
5
import math
6
import sys
7
import os.path
8
import numpy
9 10 11 12 13
import click
import matplotlib
import matplotlib.pyplot as mpl
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
14
from .. import (far_threshold, plot, utils, ppndf)
15 16 17
import logging

LOGGER = logging.getLogger("bob.measure")
18

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
19

20 21 22 23 24 25 26 27 28 29 30 31
def check_list_value(values, desired_number, name, name2='systems'):
    """Broadcast a one-element option to ``desired_number`` entries.

    Parameters
    ----------
    values :
        Sequence given by the user, or ``None`` when the option is unset.
    desired_number : int
        Number of entries expected.
    name : str
        Name of the option, only used in the error message.
    name2 : str
        Name of what the option is matched against, only used in the error
        message.

    Returns
    -------
    ``values`` unchanged when it is ``None`` or already has
    ``desired_number`` entries, otherwise its single entry repeated
    ``desired_number`` times.

    Raises
    ------
    click.BadParameter
        If ``values`` has neither one nor ``desired_number`` entries.
    """
    # nothing to adapt: option not given, or already of the right size
    if values is None or len(values) == desired_number:
        return values

    if len(values) != 1:
        raise click.BadParameter(
            '#{} ({}) must be either 1 value or the same as '
            '#{} ({} values)'.format(name, values, name2, desired_number))

    # a single value is shared by every entry
    return values * desired_number


32 33 34 35 36 37 38 39 40 41
class MeasureBase(object):
    """Base class for metrics and plots.

    This abstract class defines the framework to plot or compute metrics from
    a list of (positive, negative) scores tuples.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    n_systems: int
        Number of systems, i.e. number of input files divided by the number
        of files expected per system (``min_arg``).
    """
    __metaclass__ = ABCMeta  # for python 2.7 compatibility

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx :
            Click context; its ``meta`` dictionary is read for the options
            ``legends`` and ``min_arg`` (defaults to 1).
        scores : :any:`list`:
            List of input files (e.g. dev-{1, 2, 3}, {dev,eval}-scores1
            {dev,eval}-scores2)
        evaluation : :py:class:`bool`
            True if eval data are used
        func_load : Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If the number of score files is zero or not a multiple of
            ``min_arg``, or if fewer legends than systems are given.
        """
        self._scores = scores
        self._ctx = ctx
        self.func_load = func_load
        self._legends = ctx.meta.get('legends')
        self._eval = evaluation
        self._min_arg = ctx.meta.get('min_arg', 1)
        if len(scores) < 1 or len(scores) % self._min_arg != 0:
            raise click.BadParameter(
                'Number of argument must be a non-zero multiple of %d' % self._min_arg
            )
        # exact: divisibility was checked just above
        self.n_systems = len(scores) // self._min_arg
        if self._legends is not None and len(self._legends) < self.n_systems:
            raise click.BadParameter("Number of legends must be >= to the "
                                     "number of systems")

    def run(self):
        """ Generate outputs (e.g. metrics, files, pdf plots).
        This function calls abstract methods
        :func:`~bob.measure.script.figure.MeasureBase.init_process` (before
        loop), :py:func:`~bob.measure.script.figure.MeasureBase.compute`
        (in the loop iterating through the different
        systems) and :py:func:`~bob.measure.script.figure.MeasureBase.end_process`
        (after the loop).
        """
        # init matplotlib, log files, ...
        self.init_process()
        # iterates through the different systems and feed `compute`
        # with the dev (and eval) scores of each system
        # Note that more than one dev or eval scores score can be passed to
        # each system
        for idx in range(self.n_systems):
            # Scores are given as followed:
            # SysA-dev SysA-eval ... SysA-XX  SysB-dev SysB-eval ... SysB-XX
            # ------------------------------  ------------------------------
            #   First set of `self._min_arg`     Second set of input files
            #     input files starting at               for SysB
            #    index idx * self._min_arg
            first = idx * self._min_arg
            last = first + self._min_arg
            # load scores for each system: get the corresponding arrays and
            # base-name of files
            input_scores, input_names = self._load_files(
                self._scores[first:last])
            self.compute(idx, input_scores, input_names)
        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        before iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots from the given scores provided by
        :py:func:`~bob.measure.script.figure.MeasureBase.run`.
        Should be reimplemented in derived classes

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system

        Notes
        -----
        Structure of the input (vuln example), if evaluation is provided::

            [(dev_licit_neg, dev_licit_pos), (eval_licit_neg, eval_licit_pos),
             (dev_spoof_neg, dev_licit_pos), (eval_spoof_neg, eval_licit_pos)]

        and if only dev::

            [(dev_licit_neg, dev_licit_pos), (dev_spoof_neg, dev_licit_pos)]
        """
        pass

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """ Called in :py:func:`~bob.measure.script.figure.MeasureBase`.run
        after iterating through the different systems.
        Should be reimplemented in derived classes"""
        pass

    # common protected functions

    def _load_files(self, filepaths):
        ''' Load the input files and return the base names of the files

        Parameters
        ----------
            filepaths: :any:`list`:
                Paths of the files to load

        Returns
        -------
            scores: :any:`list`:
                A list that contains the output of
                ``func_load`` for the given files
            basenames: :any:`list`:
                A list of basenames for the given files
        '''
        scores = []
        basenames = []
        for filename in filepaths:
            # base name is cut at the FIRST dot: "a.dev.txt" -> "a"
            basenames.append(os.path.basename(filename).split(".")[0])
            scores.append(self.func_load(filename))
        return scores, basenames
162

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
163

164 165 166 167 168 169 170 171
class Metrics(MeasureBase):
    ''' Compute metrics from score files

    Attributes
    ----------
    log_file: str
        output stream
    names: tuple
        Row labels of the table printed by :py:meth:`compute`. They are
        matched, in order, with the false positive rate, false negative
        rate, precision, recall and F1-score strings.
    '''

    # Positions, in the tuple returned by `_strings`, of the values displayed
    # in the table rows (fmr, fnmr, precision, recall, f1); row `i` is
    # labelled with `self.names[i]`.
    _TABLE_COLUMNS = (1, 2, 6, 7, 8)

    def __init__(self, ctx, scores, evaluation, func_load,
                 names=('False Positive Rate', 'False Negative Rate',
                        'Precision', 'Recall', 'F1-score')):
        '''
        Parameters
        ----------
        ctx :
            Click context; ``ctx.meta`` is read for ``tablefmt``,
            ``criterion``, ``open_mode``, ``thres``, ``decimal``,
            ``far_value`` and ``log``.
        scores, evaluation, func_load :
            Forwarded to :py:class:`MeasureBase`.
        names : tuple
            Row labels. The default order follows the order of the displayed
            values (see ``_TABLE_COLUMNS``) so that each label sits next to
            the metric it names.

        Raises
        ------
        click.BadParameter
            If several thresholds are given and their number differs from
            the number of systems.
        '''
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        self.names = names
        self._tablefmt = ctx.meta.get('tablefmt')
        self._criterion = ctx.meta.get('criterion')
        self._open_mode = ctx.meta.get('open_mode')
        self._thres = ctx.meta.get('thres')
        self._decimal = ctx.meta.get('decimal', 2)
        if self._thres is not None:
            if len(self._thres) == 1:
                # a single threshold is shared by all the systems
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                # n_systems is an int: format it directly
                raise click.BadParameter(
                    '#thresholds must be the same as #systems (%d)'
                    % self.n_systems
                )
        self._far = ctx.meta.get('far_value')
        self._log = ctx.meta.get('log')
        # results go to stdout unless a log file was requested; the file is
        # closed in `end_process`
        self.log_file = sys.stdout
        if self._log is not None:
            self.log_file = open(self._log, self._open_mode)

    def get_thres(self, criterion, dev_neg, dev_pos, far):
        ''' Return the threshold computed by ``utils.get_thres`` for the
        given criterion and development scores'''
        return utils.get_thres(criterion, dev_neg, dev_pos, far)

    def _numbers(self, neg, pos, threshold, fta):
        ''' Compute the raw metrics at ``threshold``

        Returns
        -------
        tuple
            ``(fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc, precision,
            recall, f1_score)``
        '''
        from .. import (farfrr, precision_recall, f_score)
        # fpr and fnr
        fmr, fnmr = farfrr(neg, pos, threshold)
        hter = (fmr + fnmr) / 2.0
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)

        ni = neg.shape[0]  # number of impostors
        fm = int(round(fmr * ni))  # number of false accepts
        nc = pos.shape[0]  # number of clients
        fnm = int(round(fnmr * nc))  # number of false rejects

        # precision and recall
        precision, recall = precision_recall(neg, pos, threshold)

        # f_score
        f1_score = f_score(neg, pos, threshold, 1)
        return (fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc, precision,
                recall, f1_score)

    def _strings(self, metrics):
        ''' Format the tuple returned by :py:meth:`_numbers`

        Returns
        -------
        tuple
            ``(fta, fmr, fnmr, far, frr, hter, precision, recall, f1)`` as
            strings. Note that the order differs from ``_numbers``.
        '''
        fta_str = "%.1f%%" % (100 * metrics[0])
        fmr_str = "%.1f%% (%d/%d)" % (100 * metrics[1], metrics[6], metrics[7])
        fnmr_str = "%.1f%% (%d/%d)" % (100 * metrics[2], metrics[8], metrics[9])
        far_str = "%.1f%%" % (100 * metrics[4])
        frr_str = "%.1f%%" % (100 * metrics[5])
        hter_str = "%.1f%%" % (100 * metrics[3])
        prec_str = "%.1f" % (metrics[10])
        recall_str = "%.1f" % (metrics[11])
        f1_str = "%.1f" % (metrics[12])

        return (fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str,
                prec_str, recall_str, f1_str)

    def _get_all_metrics(self, idx, input_scores, input_names):
        ''' Compute all metrics for dev and eval scores

        Returns
        -------
        list
            Two entries: the formatted dev metrics, then the formatted eval
            metrics or ``None`` when no eval data are used.
        '''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]

        # a user-provided threshold takes precedence over the criterion
        if self._thres is None:
            threshold = self.get_thres(
                self._criterion, dev_neg, dev_pos, self._far)
        else:
            threshold = self._thres[idx]

        title = self._legends[idx] if self._legends is not None else None
        if self._thres is None:
            far_str = ''
            if self._criterion == 'far' and self._far is not None:
                far_str = str(self._far)
            click.echo("[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                       % (self._criterion.upper(),
                          far_str, title or dev_file,
                          threshold),
                       file=self.log_file)
        else:
            # same naming rule as above: legend first, file name otherwise
            click.echo("[Min. criterion: user provided] Threshold on "
                       "Development set `%s`: %e"
                       % (title or dev_file, threshold), file=self.log_file)

        res = []
        res.append(self._strings(self._numbers(
            dev_neg, dev_pos, threshold, dev_fta)))

        if self._eval:
            # computes statistics for the eval set based on the threshold a
            # priori
            res.append(self._strings(self._numbers(
                eval_neg, eval_pos, threshold, eval_fta)))
        else:
            res.append(None)

        return res

    def _report_nans(self, fta_str, filename):
        ''' Echo and log a warning when the formatted NaN rate ``fta_str``
        (e.g. ``"1.5%"``) is greater than zero'''
        if float(fta_str.replace('%', '')) > 0.0:
            click.echo("NaNs scores (%s) were found in %s" %
                       (fta_str, filename))
            LOGGER.warning("NaNs scores (%s) were found in %s", fta_str,
                           filename)

    def compute(self, idx, input_scores, input_names):
        ''' Compute metrics thresholds and tables (FPR, FNR, precision, recall,
        f1_score) for given system inputs'''
        dev_file = input_names[0]
        title = self._legends[idx] if self._legends is not None else None
        all_metrics = self._get_all_metrics(idx, input_scores, input_names)

        dev_strings = all_metrics[0]
        self._report_nans(dev_strings[0], dev_file)
        # the legend, when given, is the header of the first column
        headers = [title or ' ', 'Development']
        rows = [[self.names[pos], dev_strings[col]]
                for pos, col in enumerate(self._TABLE_COLUMNS)]

        if self._eval:
            eval_file = input_names[1]
            eval_strings = all_metrics[1]
            self._report_nans(eval_strings[0], eval_file)
            # statistics for the eval set are based on the threshold a priori
            headers.append('Evaluation')
            for row, col in zip(rows, self._TABLE_COLUMNS):
                row.append(eval_strings[col])

        click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)

    def end_process(self):
        ''' Close log file if needed'''
        if self._log is not None:
            self.log_file.close()

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
318

319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403
class MultiMetrics(Metrics):
    '''Computes average of metrics based on several protocols (cross
    validation)

    Attributes
    ----------
    log_file : str
        output stream
    names : tuple
        List of names for the metrics.
    headers : list
        Header row of the table printed in :py:meth:`end_process`.
    rows : list
        One row per system, filled by :py:meth:`compute`.
    '''

    def __init__(self, ctx, scores, evaluation, func_load,
                 names=('NaNs Rate', 'False Positive Rate',
                        'False Negative Rate', 'False Accept Rate',
                        'False Reject Rate', 'Half Total Error Rate')):
        super(MultiMetrics, self).__init__(
            ctx, scores, evaluation, func_load, names=names)

        self.headers = ['Methods'] + list(self.names)
        if self._eval:
            # with eval data, the dev set is only summarized by its HTER
            self.headers.insert(1, self.names[5] + ' (dev)')
        self.rows = []

    def _strings(self, metrics):
        '''Format "mean (std)" strings over the protocols.

        Parameters
        ----------
        metrics : numpy.ndarray
            2D array with one row per protocol, each row being the tuple
            returned by ``_numbers``.

        Returns
        -------
        tuple
            ``(fta, fmr, fnmr, far, frr, hter)`` as strings.
        '''
        means = metrics.mean(axis=0)
        stds = metrics.std(axis=0)

        def _mean_std(column):
            return "%.1f%% (%.1f%%)" % (100 * means[column],
                                        100 * stds[column])

        # Index the columns instead of unpacking the whole row, so that the
        # extra values `_numbers` returns after the rates are simply ignored.
        # Column order of `_numbers`: fta, fmr, fnmr, hter, far, frr, ...
        fta_str = _mean_std(0)
        fmr_str = _mean_std(1)
        fnmr_str = _mean_std(2)
        hter_str = _mean_std(3)
        far_str = _mean_std(4)
        frr_str = _mean_std(5)

        return fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str

    def compute(self, idx, input_scores, input_names):
        '''Computes the average of metrics over several protocols.'''
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        # with eval data, inputs alternate: dev, eval, dev, eval, ...
        step = 2 if self._eval else 1
        self._dev_metrics = []
        self._thresholds = []
        for i in range(0, len(input_scores), step):
            neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
            threshold = self.get_thres(self._criterion, neg, pos, self._far) \
                if self._thres is None else self._thres[idx]
            self._thresholds.append(threshold)
            self._dev_metrics.append(self._numbers(neg, pos, threshold, fta))
        self._dev_metrics = numpy.array(self._dev_metrics)

        if self._eval:
            self._eval_metrics = []
            for i in range(1, len(input_scores), step):
                neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
                # threshold of the dev set that precedes this eval set
                threshold = self._thresholds[i // 2]
                self._eval_metrics.append(
                    self._numbers(neg, pos, threshold, fta))
            self._eval_metrics = numpy.array(self._eval_metrics)

        title = self._legends[idx] if self._legends is not None else None

        fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str = \
            self._strings(self._dev_metrics)

        if self._eval:
            self.rows.append([title, hter_str])
        else:
            self.rows.append([title, fta_str, fmr_str, fnmr_str,
                              far_str, frr_str, hter_str])

        if self._eval:
            # computes statistics for the eval set based on the threshold a
            # priori
            fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str = \
                self._strings(self._eval_metrics)

            self.rows[-1].extend([fta_str, fmr_str, fnmr_str,
                                  far_str, frr_str, hter_str])

    def end_process(self):
        ''' Print the table of all systems, then close log file if needed'''
        click.echo(tabulate(self.rows, self.headers,
                            self._tablefmt), file=self.log_file)
        super(MultiMetrics, self).end_process()


404 405 406 407
class PlotBase(MeasureBase):
    ''' Base class for plots. Regroup several options and code
    shared by the different plots
    '''

    def __init__(self, ctx, scores, evaluation, func_load):
        ''' Read the plotting options from ``ctx.meta``; the positional
        parameters are forwarded to :py:class:`MeasureBase`'''
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        self._output = ctx.meta.get('output')
        self._points = ctx.meta.get('points', 100)
        self._split = ctx.meta.get('split')
        self._axlim = ctx.meta.get('axlim')
        self._disp_legend = ctx.meta.get('disp_legend', True)
        self._legend_loc = ctx.meta.get('legend_loc')
        # power of ten taken from `min_far_value` or, failing that, from the
        # lower x limit; 0 is used when that limit is 0
        self._min_dig = None
        if 'min_far_value' in ctx.meta:
            self._min_dig = int(math.log10(ctx.meta['min_far_value']))
        elif self._axlim is not None and self._axlim[0] is not None:
            self._min_dig = int(math.log10(self._axlim[0])
                                if self._axlim[0] != 0 else 0)
        self._clayout = ctx.meta.get('clayout')
        self._far_at = ctx.meta.get('lines_at')
        # x position of the vertical lines; subclasses may transform it
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            self._eval_points = {line: [] for line in self._far_at}
            self._lines_val = []
        self._print_fn = ctx.meta.get('show_fn', True)
        self._x_rotation = ctx.meta.get('x_rotation')
        if 'style' in ctx.meta:
            mpl.style.use(ctx.meta['style'])
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._line_linestyles = ctx.meta.get('line_linestyles', False)
        self._linestyles = utils.get_linestyles(
            self.n_systems, self._line_linestyles)
        # titles are doubled so that one exists for each of the two possible
        # figures; `or []` also covers a key stored with value None
        self._titles = (ctx.meta.get('titles') or []) * 2
        # for compatibility
        self._title = ctx.meta.get('title')
        if not self._titles and self._title is not None:
            self._titles = [self._title] * 2

        self._x_label = ctx.meta.get('x_label')
        self._y_label = ctx.meta.get('y_label')
        self._grid_color = 'silver'
        self._pdf_page = None
        self._end_setup_plot = True

    def init_process(self):
        ''' Open pdf and set axis font size if provided '''
        if not hasattr(matplotlib, 'backends'):
            matplotlib.use('pdf')

        # reuse the PdfPages shared through the context when there is one,
        # otherwise open our own on the output path
        if 'PdfPages' in self._ctx.meta:
            self._pdf_page = self._ctx.meta['PdfPages']
        else:
            self._pdf_page = PdfPages(self._output)

        for i in range(self._nb_figs):
            fs = self._ctx.meta.get('figsize')
            fig = mpl.figure(i + 1, figsize=fs)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        ''' Set title, legend, axis labels, grid colors, save figures, draw
        lines and close pdf if needed '''
        # draw vertical lines
        if self._far_at is not None:
            for (line, line_trans) in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot(
                    [line_trans, line_trans], [-100.0, 100.], "--",
                    color='black'
                )
                if self._eval and self._split:
                    mpl.figure(2)
                    # join the eval points of this line, ordered by x
                    # (sorted is stable, ties keep their insertion order)
                    points = sorted(self._eval_points[line],
                                    key=lambda point: point[0])
                    x_values = [x for x, _ in points]
                    y_values = [y for _, y in points]
                    mpl.plot(x_values,
                             y_values, '--',
                             color='black')
        # only for plots
        if self._end_setup_plot:
            for i in range(self._nb_figs):
                fig = mpl.figure(i + 1)
                title = '' if not self._titles else self._titles[i]
                # a title made only of spaces is displayed as no title
                mpl.title(title if title.replace(' ', '') else '')
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                if self._disp_legend:
                    mpl.legend(loc=self._legend_loc)
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        # do not want to close PDF when running evaluate (`closef` set to a
        # false value). A PdfPages opened by `init_process` itself is closed
        # under the same rule, otherwise nothing would ever close it.
        if self._pdf_page is not None and \
           ('closef' not in self._ctx.meta or self._ctx.meta['closef']):
            self._pdf_page.close()

    # common protected functions

    def _label(self, base, name, idx):
        ''' Return the legend of system ``idx`` when one is available,
        otherwise a label built from ``base`` and ``name`` (numbered when
        there are several systems)'''
        if self._legends is not None and len(self._legends) > idx:
            return self._legends[idx]
        if self.n_systems > 1:
            return base + (" %d (%s)" % (idx + 1, name))
        return base + (" (%s)" % name)

    def _set_axis(self):
        ''' Apply the axis limits to the current figure, if any were set'''
        if self._axlim is not None:
            mpl.axis(self._axlim)
519

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
520

521
class Roc(PlotBase):
    ''' Handles the plotting of ROC'''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ['ROC dev.', 'ROC eval.']
        self._x_label = self._x_label or 'False Positive Rate'
        self._y_label = self._y_label or "1 - False Negative Rate"
        self._semilogx = ctx.meta.get('semilogx', True)
        best_legend = 'lower right' if self._semilogx else 'upper right'
        self._legend_loc = self._legend_loc or best_legend
        # custom defaults
        if self._axlim is None:
            self._axlim = [None, None, -0.05, 1.05]

    def _far_values(self):
        ''' Return the FAR values given to ``plot.roc_for_far``; a
        ``_min_dig`` of ``None`` or 0 falls back to -4'''
        return plot.log_values(self._min_dig or -4)

    def compute(self, idx, input_scores, input_names):
        ''' Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc`'''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        # the dev curve is always drawn on the first figure
        mpl.figure(1)
        plot.roc_for_far(
            dev_neg, dev_pos,
            far_values=self._far_values(),
            CAR=self._semilogx,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label('dev', dev_file, idx)
        )
        if not self._eval:
            return

        # the eval curve goes on its own figure when split, otherwise it is
        # drawn dashed on the dev figure
        if self._split:
            mpl.figure(2)
        linestyle = '--' if not self._split else self._linestyles[idx]
        plot.roc_for_far(
            eval_neg, eval_pos, linestyle=linestyle,
            far_values=self._far_values(),
            CAR=self._semilogx,
            color=self._colors[idx],
            label=self._label('eval.', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                # threshold taken on dev, applied a priori to eval
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(
                    eval_neg, eval_pos, thres_line)
                eval_fnmr = 1 - eval_fnmr
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
584

585 586
class Det(PlotBase):
    ''' Handles the plotting of DET '''

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ['DET dev.', 'DET eval.']
        self._x_label = self._x_label or 'False Positive Rate (%)'
        self._y_label = self._y_label or 'False Negative Rate (%)'
        self._legend_loc = self._legend_loc or 'upper right'
        if self._far_at is not None:
            # vertical lines are positioned with the ppndf-transformed values
            self._trans_far_val = [ppndf(float(k)) for k in self._far_at]
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

        if self._axlim is None:
            self._axlim = [0.01, 99, 0.01, 99]

        if self._min_dig is not None:
            # work on a copy: the limits may be the very object stored in
            # `ctx.meta` (shared with other commands) or an immutable tuple
            self._axlim = list(self._axlim)
            self._axlim[0] = math.pow(10, self._min_dig) * 100

    def compute(self, idx, input_scores, input_names):
        ''' Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`'''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        # the dev curve is always drawn on the first figure
        mpl.figure(1)
        plot.det(
            dev_neg, dev_pos, self._points, color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label('development', dev_file, idx)
        )
        if not (self._eval and eval_neg is not None):
            return

        # the eval curve goes on its own figure when split, otherwise it is
        # drawn dashed on the dev figure
        if self._split:
            mpl.figure(2)
        linestyle = '--' if not self._split else self._linestyles[idx]
        plot.det(
            eval_neg, eval_pos, self._points, color=self._colors[idx],
            linestyle=linestyle,
            label=self._label('eval.', eval_file, idx)
        )
        if self._far_at is not None:
            from .. import farfrr
            for line in self._far_at:
                # threshold taken on dev, applied a priori to eval
                thres_line = far_threshold(dev_neg, dev_pos, line)
                eval_fmr, eval_fnmr = farfrr(
                    eval_neg, eval_pos, thres_line)
                eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
                mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
                self._eval_points[line].append((eval_fmr, eval_fnmr))

    def _set_axis(self):
        ''' Apply the axis limits through ``plot.det_axis``'''
        plot.det_axis(self._axlim)
649

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
650

651 652
class Epc(PlotBase):
    ''' Handles the plotting of EPC '''

    def __init__(self, ctx, scores, evaluation, func_load, hter='HTER'):
        '''
        Parameters
        ----------
        ctx, scores, evaluation, func_load :
            Forwarded to :py:class:`PlotBase`.
        hter : str
            Name of the error rate, used in the default y-axis label.

        Raises
        ------
        click.UsageError
            If ``min_arg`` is not 2, i.e. when dev and eval score files are
            not both expected for each system.
        '''
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev. and eval. score files")
        self._titles = self._titles or ['EPC'] * 2
        self._x_label = self._x_label or r'$\alpha$'
        self._y_label = self._y_label or hter + ' (%)'
        self._legend_loc = self._legend_loc or 'upper center'
        self._eval = True  # always eval data with EPC
        self._split = False
        self._nb_figs = 1
        self._far_at = None

    def compute(self, idx, input_scores, input_names):
        ''' Plot EPC using :py:func:`bob.measure.plot.epc` '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        # eval data are always present: the constructor requires two score
        # files per system and forces `_eval` to True
        eval_neg, eval_pos = neg_list[1], pos_list[1]
        eval_file = input_names[1]

        plot.epc(
            dev_neg, dev_pos, eval_neg, eval_pos, self._points,
            color=self._colors[idx], linestyle=self._linestyles[idx],
            label=self._label(
                'curve', dev_file + "_" + eval_file, idx
            )
        )

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
684

685
class Hist(PlotBase):
686
    ''' Functional base class for histograms'''
Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
687

688
    def __init__(self, ctx, scores, evaluation, func_load, nhist_per_system=2):
        '''
        Parameters
        ----------
        ctx, scores, evaluation, func_load :
            Forwarded unchanged to the parent constructor. ``ctx.meta`` is
            also read here for the histogram specific options (``n_bins``,
            ``thres``, ``criterion``, ``no_line``, ``n_row``, ``n_col``,
            ``hide_dev``, ``legends_ncol`` and ``titles``).
        nhist_per_system : :py:class:`int`
            Number of histograms drawn for one system. ``n_bins`` must hold
            either one value or this many values.

        Raises
        ------
        click.BadParameter
            If ``hide_dev`` is requested while eval data are not used, or if
            ``check_list_value`` rejects the number of ``n_bins`` or
            ``thres`` values.
        '''
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        self._nbins = ctx.meta.get('n_bins', ['doane'])
        self._nhist_per_system = nhist_per_system
        self._nbins = check_list_value(
            self._nbins, nhist_per_system, 'n_bins',
            'histograms')
        self._thres = ctx.meta.get('thres')
        self._thres = check_list_value(
            self._thres, self.n_systems, 'thresholds')
        self._criterion = ctx.meta.get('criterion')
        # no vertical (threshold) is displayed
        self._no_line = ctx.meta.get('no_line', False)
        # subplot grid
        self._nrows = ctx.meta.get('n_row', 1)
        self._ncols = ctx.meta.get('n_col', 1)
        # do not display dev histo
        self._hide_dev = ctx.meta.get('hide_dev', False)
        if self._hide_dev and not self._eval:
            raise click.BadParameter(
                "You can only use --hide-dev along with --eval")
        # dev hist are displayed next to eval hist
        self._nrows *= 1 if self._hide_dev or not self._eval else 2
        self._nlegends = ctx.meta.get('legends_ncol', 3)
        self._legend_loc = self._legend_loc or 'upper center'
        # number of subplot on one page
        self._step_print = int(self._nrows * self._ncols)
        self._title_base = 'Scores'
        self._y_label = 'Probability density'
        self._x_label = 'Score values'
        self._end_setup_plot = False
        # override _titles of PlotBase; the list is doubled so that one
        # title per system can be looked up at ``sys`` or at
        # ``sys + n_systems``. ``or []`` covers a 'titles' entry that is
        # present but set to None, which made ``None * 2`` raise TypeError.
        self._titles = (ctx.meta.get('titles') or []) * 2
721

722
    def compute(self, idx, input_scores, input_names):
723
        ''' Draw histograms of negative and positive scores.'''
724
        dev_neg, dev_pos, eval_neg, eval_pos, threshold = \
Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
725
            self._get_neg_pos_thres(idx, input_scores, input_names)
726 727 728 729 730 731 732 733
        # keep id of the current system
        sys = idx
        # if the id of the current system does not match the id of the plot, 
        # change it
        if not self._hide_dev and self._eval:
            row = int(idx / self._ncols) * 2
            col = idx % self._ncols
            idx = col + self._ncols * row
734 735

        if not self._hide_dev or not self._eval:
736
            self._print_subplot(idx, sys, dev_neg, dev_pos, threshold,
737
                                not self._no_line, False)
738 739

        if self._eval:
740 741
            idx += self._ncols if not self._hide_dev else 0
            self._print_subplot(idx, sys, eval_neg, eval_pos, threshold,
742
                                not self._no_line, True)
743

744
    def _print_subplot(self, idx, sys, neg, pos, threshold, draw_line,
                       evaluation):
        ''' Draw one histogram subplot and save the page once it is complete.

        Parameters
        ----------
        idx : :py:class:`int`
            Subplot index; it is reduced modulo the number of subplots per
            page to find the position on the current page.
        sys : :py:class:`int`
            Index of the system the scores belong to.
        neg, pos :
            Scores forwarded to ``_setup_hist`` and ``_lines``.
        threshold :
            Threshold forwarded to ``_lines``.
        draw_line : :py:class:`bool`
            Whether the threshold line is drawn.
        evaluation : :py:class:`bool`
            True when the subplot shows evaluation scores.
        '''
        slot = idx % self._step_print
        subplot_number = slot + 1
        axis = mpl.subplot(self._nrows, self._ncols, subplot_number)
        self._setup_hist(neg, pos)
        # y label on the first column only
        if slot % self._ncols == 0:
            axis.set_ylabel(self._y_label)

        # one histogram per system when dev is hidden or eval is not used,
        # two otherwise
        single_set = self._hide_dev or not self._eval
        per_page = self._step_print / (1 if single_set else 2)
        position = sys % per_page
        # systems still to be printed, counted from the start of this page
        remaining = self.n_systems - int(sys / per_page) * per_page
        # x label on the lower histograms only
        is_lower = evaluation or not self._eval
        if is_lower and position + self._ncols >= min(per_page, remaining):
            axis.set_xlabel(self._x_label)

        if self.n_systems == 1 and single_set:
            default_title = " "
        elif evaluation:
            default_title = "Eval. scores"
        else:
            default_title = "Dev. scores"
        title_idx = sys + self.n_systems if is_lower else sys
        axis.set_title(self._get_title(title_idx, default_title))

        criterion = '' if self._criterion is None else self._criterion.upper()
        suffix = ' (dev)' if self._eval else ''
        label = "%s threshold%s" % (criterion, suffix)
        if draw_line:
            self._lines(threshold, label, neg, pos, idx)

        # save the figure after the last subplot of the page, or after the
        # last subplot to display
        page_full = self._step_print == subplot_number
        last_plot = is_lower and sys == self.n_systems - 1
        if page_full or last_plot:
            # print legend on the page
            self.plot_legends()
            mpl.tight_layout()
            self._pdf_page.savefig(mpl.gcf(), bbox_inches='tight')
            mpl.clf()
            mpl.figure()
786

787
    def _get_title(self, idx, dflt=None):
788
        ''' Get the histo title for the given idx'''
789 790
        title = self._titles[idx] if self._titles is not None \
            and idx < len(self._titles) else dflt
791
        title = title or self._title_base
792 793
        title = '' if title is not None and not title.replace(
            ' ', '') else title
794
        return title or ''
795

796 797
    def plot_legends(self):
        ''' Print one legend for the whole current figure

        Handles and labels are gathered from every axes of the current
        figure, keeping only the first handle of each label. The legend is
        drawn only when ``self._disp_legend`` is set.
        '''
        figure = mpl.gcf()
        handles, names = [], []
        for axis in figure.get_axes():
            for handle, name in zip(*axis.get_legend_handles_labels()):
                # avoid duplicates in legend
                if name in names:
                    continue
                handles.append(handle)
                names.append(name)

        if not self._disp_legend:
            return
        figure.legend(
            handles, names, loc=self._legend_loc, fancybox=True,
            framealpha=0.5, ncol=self._nlegends,
            bbox_to_anchor=(0.55, 1.1),
        )
814

815
    def _get_neg_pos_thres(self, idx, input_scores, input_names):
        ''' Get scores and threshold for the given system at index idx

        Returns
        -------
        tuple
            ``(dev_neg, dev_pos, eval_neg, eval_pos, threshold)``. The first
            four items are lists (one entry per score file of the system);
            ``eval_neg`` and ``eval_pos`` are None when eval data are not
            used. ``threshold`` is ``self._thres[idx]`` when thresholds were
            given, else it is computed by ``utils.get_thres`` from the first
            development set.
        '''
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        length = len(neg_list)
        # lists returned by get_fta_list contain the following items:
        # for bio or measure without eval:
        #   [dev]
        # for bio or measure with eval:
        #   [dev, eval]
        # for vuln with {licit,spoof} without eval:
        #   [licit_dev, spoof_dev]
        # for vuln with {licit,spoof} with eval:
        #   [licit_dev, licit_eval, spoof_dev, spoof_eval]
        step = 2 if self._eval else 1
        # can have several files for one system
        dev_indices = range(0, length, step)
        dev_neg = [neg_list[i] for i in dev_indices]
        dev_pos = [pos_list[i] for i in dev_indices]

        eval_neg = eval_pos = None
        if self._eval:
            eval_indices = range(1, length, step)
            eval_neg = [neg_list[i] for i in eval_indices]
            eval_pos = [pos_list[i] for i in eval_indices]

        if self._thres is None:
            threshold = utils.get_thres(
                self._criterion, dev_neg[0], dev_pos[0]
            )
        else:
            threshold = self._thres[idx]
        return dev_neg, dev_pos, eval_neg, eval_pos, threshold
841

842
    def _density_hist(self, scores, n, **kwargs):
        ''' Plots one density histo

        ``n`` selects the binning in ``self._nbins``; extra keyword
        arguments are passed to ``mpl.hist``, whose three return values are
        returned as a tuple.
        '''
        bin_spec = self._nbins[n]
        counts, edges, patches = mpl.hist(
            scores, density=True, bins=bin_spec, **kwargs
        )
        return counts, edges, patches

851 852
    def _lines(self, threshold, label=None, neg=None, pos=None,
               idx=None, **kwargs):
        ''' Plots vertical line at threshold

        ``neg``, ``pos`` and ``idx`` are not used here. Color, linestyle and
        label given in ``kwargs`` take precedence over the defaults.
        '''
        defaults = (
            ('color', 'C3'),
            ('linestyle', '--'),
            ('label', label or 'Threshold'),
        )
        for key, value in defaults:
            kwargs.setdefault(key, value)
        # plot a vertical threshold line
        mpl.axvline(x=threshold, ymin=0, ymax=1, **kwargs)
860 861

    def _setup_hist(self, neg, pos):
862 863 864 865 866
        ''' This function can be overwritten in derived classes

        Plots all the density histo required in one plot. Here negative and
        positive scores densities.
        '''
867
        self._density_hist(
868 869
            neg[0], n=0,
            label='Negatives', alpha=0.5, color='C3'
870 871
        )
        self._density_hist(
872 873
            pos[0], n=1,
            label='Positives', alpha=0.5, color='C0'
874
        )