figure.py 40.3 KB
Newer Older
1
"""Runs error analysis on score sets, outputs metrics and plots"""
2 3 4

from __future__ import division, print_function
from abc import ABCMeta, abstractmethod
5
import math
6
import sys
7
import numpy
8 9 10
import click
import matplotlib
import matplotlib.pyplot as mpl
11
from matplotlib import gridspec
12 13
from matplotlib.backends.backend_pdf import PdfPages
from tabulate import tabulate
14
from .. import far_threshold, plot, utils, ppndf
15 16 17
import logging

LOGGER = logging.getLogger("bob.measure")
18

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
19

20
def check_list_value(values, desired_number, name, name2="systems"):
    """Make sure ``values`` holds ``desired_number`` items.

    Parameters
    ----------
    values : sequence or None
        Option values given by the user. ``None`` is returned untouched.
    desired_number : int
        Expected number of items.
    name : str
        Name of the option, only used in the error message.
    name2 : str
        Name of what ``desired_number`` counts, only used in the error
        message.

    Returns
    -------
    sequence or None
        ``values`` itself when it already has the right length (or is
        ``None``); ``values`` repeated ``desired_number`` times when it holds
        exactly one item.

    Raises
    ------
    click.BadParameter
        If ``values`` has neither one item nor ``desired_number`` items.
    """
    if values is None:
        return values

    count = len(values)
    if count == desired_number:
        return values

    if count != 1:
        raise click.BadParameter(
            "#{} ({}) must be either 1 value or the same as "
            "#{} ({} values)".format(name, values, name2, desired_number)
        )

    # a single value is broadcast to every entry
    return values * desired_number


33 34 35 36 37 38 39 40 41 42
class MeasureBase(object):
    """Base class for metrics and plots.

    Defines the skeleton used to compute metrics or draw plots from a list of
    score files: :py:meth:`run` loads the files of each system with
    ``func_load`` and hands them to :py:meth:`compute`.

    Attributes
    ----------
    func_load:
        Function that is used to load the input files
    n_systems : int
        Number of systems, i.e. number of input files divided by the number
        of files expected per system (``min_arg`` in the click context).
    """

    __metaclass__ = ABCMeta  # for python 2.7 compatibility

    def __init__(self, ctx, scores, evaluation, func_load):
        """
        Parameters
        ----------
        ctx : :py:class:`dict`
            Click context dictionary.

        scores : :any:`list`:
            List of input files (e.g. dev-{1, 2, 3}, {dev,eval}-scores1
            {dev,eval}-scores2)
        evaluation : :py:class:`bool`
            True if eval data are used
        func_load : Function that is used to load the input files

        Raises
        ------
        click.BadParameter
            If the number of files is zero or not a multiple of ``min_arg``,
            or if fewer legends than systems are given.
        """
        self._scores = scores
        self._ctx = ctx
        self.func_load = func_load
        self._legends = ctx.meta.get("legends")
        self._eval = evaluation
        self._min_arg = ctx.meta.get("min_arg", 1)

        n_files = len(scores)
        if n_files == 0 or n_files % self._min_arg:
            raise click.BadParameter(
                "Number of argument must be a non-zero multiple of %d" % self._min_arg
            )
        # exact division: the check above guarantees a zero remainder
        self.n_systems = n_files // self._min_arg

        if self._legends is not None and len(self._legends) < self.n_systems:
            raise click.BadParameter(
                "Number of legends must be >= to the number of systems"
            )

    def run(self):
        """Generate outputs (e.g. metrics, files, pdf plots).

        Calls :py:meth:`init_process` once, then :py:meth:`compute` for every
        system, and finally :py:meth:`end_process`.
        """
        # init matplotlib, log files, ...
        self.init_process()

        # Scores are given as followed:
        # SysA-dev SysA-eval ... SysA-XX  SysB-dev SysB-eval ... SysB-XX
        # ------------------------------  ------------------------------
        #   First set of `self._min_arg`     Second set of input files
        #     input files starting at               for SysB
        #    index idx * self._min_arg
        # Note that more than one dev or eval scores can be passed to each
        # system.
        for idx in range(self.n_systems):
            first = idx * self._min_arg
            system_files = self._scores[first : first + self._min_arg]
            # arrays and names of the files of the current system
            input_scores, input_names = self._load_files(system_files)

            LOGGER.info("-----Input files for system %d-----", idx + 1)
            for position, name in enumerate(input_names):
                if not self._eval:
                    LOGGER.info("Dev. score %d: %s", position + 1, name)
                    continue
                # with eval data, files alternate dev, eval, dev, eval, ...
                pair_number = position // 2 + 1
                if position % 2 == 0:
                    LOGGER.info("Dev. score %d: %s", pair_number, name)
                else:
                    LOGGER.info("Eval. score %d: %s", pair_number, name)
            LOGGER.info("----------------------------------")

            self.compute(idx, input_scores, input_names)

        # setup final configuration, plotting properties, ...
        self.end_process()

    # protected functions that need to be overwritten
    def init_process(self):
        """Hook called by :py:meth:`run` before iterating over the systems.

        Does nothing here; should be reimplemented in derived classes."""
        pass

    # Main computations are done here in the subclasses
    @abstractmethod
    def compute(self, idx, input_scores, input_names):
        """Compute metrics or plots from the scores provided by
        :py:meth:`run`.

        Should be reimplemented in derived classes.

        Parameters
        ----------
        idx : :obj:`int`
            index of the system
        input_scores: :any:`list`
            list of scores returned by the loading function
        input_names: :any:`list`
            list of base names for the input file of the system
        """
        # structure of input is (vuln example):
        # if evaluation is provided
        # [ (dev_licit_neg, dev_licit_pos), (eval_licit_neg, eval_licit_pos),
        #   (dev_spoof_neg, dev_licit_pos), (eval_spoof_neg, eval_licit_pos)]
        # and if only dev:
        # [ (dev_licit_neg, dev_licit_pos), (dev_spoof_neg, dev_licit_pos)]
        pass

    # Things to do after the main iterative computations are done
    @abstractmethod
    def end_process(self):
        """Hook called by :py:meth:`run` after iterating over the systems.

        Should be reimplemented in derived classes."""
        pass

    # common protected functions

    def _load_files(self, filepaths):
        """Load the input files and return them with their names.

        Parameters
        ----------
        filepaths : iterable
            Files to load.

        Returns
        -------
            scores: :any:`list`:
                A list that contains the output of
                ``func_load`` for the given files
            basenames: :any:`list`:
                A list of the given files
        """
        basenames = list(filepaths)
        scores = [self.func_load(path) for path in basenames]
        return scores, basenames
176

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
177

178
class Metrics(MeasureBase):
    """Compute metrics from score files.

    For every system a threshold is either computed on the development scores
    (``criterion`` / ``far_value`` from the click context) or taken from the
    user (``thres``); a table of metrics at that threshold is then printed to
    ``log_file``.

    Attributes
    ----------
    log_file: file-like
        Output stream: ``sys.stdout``, or the file named by the ``log`` entry
        of the click context (opened with ``open_mode``).
    names: tuple
        Row labels of the printed table.
    """

    def __init__(
        self,
        ctx,
        scores,
        evaluation,
        func_load,
        names=(
            "False Positive Rate",
            "False Negative Rate",
            "Precision",
            "Recall",
            "F1-score",
            "Area Under ROC Curve",
            "Area Under ROC Curve (log scale)",
        ),
    ):
        """
        Parameters
        ----------
        ctx, scores, evaluation, func_load
            Forwarded to the base class constructor.
        names : tuple
            Row labels of the metrics table.

        Raises
        ------
        click.BadParameter
            If several thresholds are given and their number differs from
            the number of systems.
        """
        super(Metrics, self).__init__(ctx, scores, evaluation, func_load)
        self.names = names
        self._tablefmt = ctx.meta.get("tablefmt")
        self._criterion = ctx.meta.get("criterion")
        self._open_mode = ctx.meta.get("open_mode")
        self._thres = ctx.meta.get("thres")
        self._decimal = ctx.meta.get("decimal", 2)
        if self._thres is not None:
            if len(self._thres) == 1:
                # a single threshold applies to every system
                self._thres = self._thres * self.n_systems
            elif len(self._thres) != self.n_systems:
                # n_systems is an int: format it directly (calling len() on
                # it raised TypeError instead of the intended error)
                raise click.BadParameter(
                    "#thresholds must be the same as #systems (%d)"
                    % self.n_systems
                )
        self._far = ctx.meta.get("far_value")
        self._log = ctx.meta.get("log")
        self.log_file = sys.stdout
        if self._log is not None:
            # closed in end_process
            self.log_file = open(self._log, self._open_mode)

    def get_thres(self, criterion, dev_neg, dev_pos, far):
        """Return the threshold given by ``utils.get_thres`` for the
        development scores."""
        return utils.get_thres(criterion, dev_neg, dev_pos, far)

    def _numbers(self, neg, pos, threshold, fta):
        """Compute the raw metrics of one score set at ``threshold``.

        Parameters
        ----------
        neg, pos : array-like
            Negative and positive scores (``.shape[0]`` is used as their
            count).
        threshold : float
            Decision threshold.
        fta : float
            Failure-to-acquire rate of the set.

        Returns
        -------
        tuple
            ``(fta, fmr, fnmr, hter, far, frr, fm, ni, fnm, nc, precision,
            recall, f1_score, auc, auc_log)``.
        """
        from .. import farfrr, precision_recall, f_score, roc_auc_score

        # fpr and fnr
        fmr, fnmr = farfrr(neg, pos, threshold)
        hter = (fmr + fnmr) / 2.0
        # rates corrected by the failure-to-acquire rate
        far = fmr * (1 - fta)
        frr = fta + fnmr * (1 - fta)

        ni = neg.shape[0]  # number of impostors
        fm = int(round(fmr * ni))  # number of false accepts
        nc = pos.shape[0]  # number of clients
        fnm = int(round(fnmr * nc))  # number of false rejects

        # precision and recall
        precision, recall = precision_recall(neg, pos, threshold)

        # f_score
        f1_score = f_score(neg, pos, threshold, 1)

        # AUC ROC
        auc = roc_auc_score(neg, pos)
        auc_log = roc_auc_score(neg, pos, log_scale=True)
        return (
            fta,
            fmr,
            fnmr,
            hter,
            far,
            frr,
            fm,
            ni,
            fnm,
            nc,
            precision,
            recall,
            f1_score,
            auc,
            auc_log,
        )

    def _strings(self, metrics):
        """Format the tuple returned by :py:meth:`_numbers`.

        Returns
        -------
        tuple
            ``(fta, fmr, fnmr, far, frr, hter, precision, recall, f1, auc,
            auc_log)`` as strings; rates are percentages, ``fmr`` and
            ``fnmr`` also show ``(errors/total)`` counts.
        """
        n_dec = ".%df" % self._decimal

        def percent(value):
            return "%s%%" % format(100 * value, n_dec)

        def plain(value):
            return "%s" % format(value, n_dec)

        fta_str = percent(metrics[0])
        fmr_str = "%s (%d/%d)" % (percent(metrics[1]), metrics[6], metrics[7])
        fnmr_str = "%s (%d/%d)" % (percent(metrics[2]), metrics[8], metrics[9])
        far_str = percent(metrics[4])
        frr_str = percent(metrics[5])
        hter_str = percent(metrics[3])
        prec_str = plain(metrics[10])
        recall_str = plain(metrics[11])
        f1_str = plain(metrics[12])
        auc_str = plain(metrics[13])
        auc_log_str = plain(metrics[14])

        return (
            fta_str,
            fmr_str,
            fnmr_str,
            far_str,
            frr_str,
            hter_str,
            prec_str,
            recall_str,
            f1_str,
            auc_str,
            auc_log_str,
        )

    def _get_all_metrics(self, idx, input_scores, input_names):
        """Compute all metrics for dev and eval scores.

        Prints the threshold in use to ``log_file``.

        Returns
        -------
        list
            ``[dev_strings, eval_strings]`` where each item is the output of
            :py:meth:`_strings`; ``eval_strings`` is ``None`` when no eval
            data are used.
        """
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        dev_neg, dev_pos, dev_fta = neg_list[0], pos_list[0], fta_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos, eval_fta = neg_list[1], pos_list[1], fta_list[1]

        threshold = (
            self.get_thres(self._criterion, dev_neg, dev_pos, self._far)
            if self._thres is None
            else self._thres[idx]
        )

        title = self._legends[idx] if self._legends is not None else None
        # the legend, when given, names the system in both messages below
        set_name = title or dev_file
        if self._thres is None:
            far_str = ""
            if self._criterion == "far" and self._far is not None:
                far_str = str(self._far)
            click.echo(
                "[Min. criterion: %s %s] Threshold on Development set `%s`: %e"
                % (self._criterion.upper(), far_str, set_name, threshold),
                file=self.log_file,
            )
        else:
            click.echo(
                "[Min. criterion: user provided] Threshold on "
                "Development set `%s`: %e" % (set_name, threshold),
                file=self.log_file,
            )

        res = [self._strings(self._numbers(dev_neg, dev_pos, threshold, dev_fta))]

        if self._eval:
            # computes statistics for the eval set based on the threshold a
            # priori
            res.append(
                self._strings(self._numbers(eval_neg, eval_pos, threshold, eval_fta))
            )
        else:
            res.append(None)

        return res

    def compute(self, idx, input_scores, input_names):
        """Compute metrics thresholds and tables (FPR, FNR, precision, recall,
        f1_score, AUC) for given system inputs and print the table to
        ``log_file``."""
        dev_file = input_names[0]
        title = self._legends[idx] if self._legends is not None else None
        all_metrics = self._get_all_metrics(idx, input_scores, input_names)
        dev_metrics = all_metrics[0]

        # a non-zero failure-to-acquire rate is reported to the user
        fta_dev = float(dev_metrics[0].replace("%", ""))
        if fta_dev > 0.0:
            LOGGER.warning(
                "NaNs scores (%s) were found in %s and removed",
                dev_metrics[0],
                dev_file,
            )

        # first header cell is the legend when there is one (the previous
        # `" " or title` always evaluated to " ")
        headers = [title or " ", "Development"]
        # position in the `_strings` output of the value shown on each row
        columns = (1, 2, 6, 7, 8, 9, 10)
        rows = [
            [self.names[row], dev_metrics[col]] for row, col in enumerate(columns)
        ]

        if self._eval:
            eval_file = input_names[1]
            eval_metrics = all_metrics[1]
            fta_eval = float(eval_metrics[0].replace("%", ""))
            if fta_eval > 0.0:
                LOGGER.warning(
                    "NaNs scores (%s) were found in %s and removed.",
                    eval_metrics[0],
                    eval_file,
                )
            # statistics for the eval set are based on the threshold a priori
            headers.append("Evaluation")
            for row, col in zip(rows, columns):
                row.append(eval_metrics[col])

        click.echo(tabulate(rows, headers, self._tablefmt), file=self.log_file)

    def end_process(self):
        """Close log file if needed"""
        if self._log is not None:
            self.log_file.close()

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
400

401
class MultiMetrics(Metrics):
    """Computes average of metrics based on several protocols (cross
    validation)

    One table row is accumulated per system in :py:meth:`compute`; the table
    is printed in :py:meth:`end_process`.

    Attributes
    ----------
    log_file : file-like
        output stream
    names : tuple
        List of names for the metrics.
    headers : list
        Column headers of the printed table.
    rows : list
        Rows of the printed table, one per system.
    """

    def __init__(
        self,
        ctx,
        scores,
        evaluation,
        func_load,
        names=(
            "NaNs Rate",
            "False Positive Rate",
            "False Negative Rate",
            "False Accept Rate",
            "False Reject Rate",
            "Half Total Error Rate",
        ),
    ):
        super(MultiMetrics, self).__init__(
            ctx, scores, evaluation, func_load, names=names
        )

        self.headers = ["Methods"] + list(self.names)
        if self._eval:
            # with eval data, the dev value of the sixth metric is shown
            # right after the method name
            self.headers.insert(1, self.names[5] + " (dev)")
        self.rows = []

    def _strings(self, metrics):
        """Format the mean and standard deviation of the metrics.

        Parameters
        ----------
        metrics : numpy.ndarray
            2D array with one row per protocol, each row being the output of
            ``_numbers``. Only the first six columns (fta, fmr, fnmr, hter,
            far, frr) are used.

        Returns
        -------
        tuple
            ``(fta, fmr, fnmr, far, frr, hter)`` strings formatted as
            ``mean% (std%)``.
        """
        # Index the needed columns instead of unpacking the whole row: the
        # row length follows `_numbers` (15 values here), so unpacking into
        # a fixed number of names raised ValueError.
        means = metrics.mean(axis=0)
        stds = metrics.std(axis=0)
        n_dec = ".%df" % self._decimal

        def mean_std(column):
            return "%s%% (%s%%)" % (
                format(100 * means[column], n_dec),
                format(100 * stds[column], n_dec),
            )

        fta_str = mean_std(0)
        fmr_str = mean_std(1)
        fnmr_str = mean_std(2)
        hter_str = mean_std(3)
        far_str = mean_std(4)
        frr_str = mean_std(5)
        return fta_str, fmr_str, fnmr_str, far_str, frr_str, hter_str

    def compute(self, idx, input_scores, input_names):
        """Computes the average of metrics over several protocols.

        With eval data, ``input_scores`` alternates dev and eval sets; each
        eval set is evaluated at the threshold of the dev set that precedes
        it.
        """
        neg_list, pos_list, fta_list = utils.get_fta_list(input_scores)
        step = 2 if self._eval else 1

        self._thresholds = []
        dev_metrics = []
        for i in range(0, len(input_scores), step):
            neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
            threshold = (
                self.get_thres(self._criterion, neg, pos, self._far)
                if self._thres is None
                else self._thres[idx]
            )
            self._thresholds.append(threshold)
            dev_metrics.append(self._numbers(neg, pos, threshold, fta))
        self._dev_metrics = numpy.array(dev_metrics)

        if self._eval:
            eval_metrics = []
            for i in range(1, len(input_scores), step):
                neg, pos, fta = neg_list[i], pos_list[i], fta_list[i]
                # threshold of the matching dev set
                threshold = self._thresholds[i // 2]
                eval_metrics.append(self._numbers(neg, pos, threshold, fta))
            self._eval_metrics = numpy.array(eval_metrics)

        title = self._legends[idx] if self._legends is not None else None

        dev_strings = self._strings(self._dev_metrics)
        if self._eval:
            # only the last dev string (hter) is shown, followed by all the
            # eval statistics, computed at the threshold chosen a priori
            row = [title, dev_strings[-1]]
            row.extend(self._strings(self._eval_metrics))
        else:
            row = [title] + list(dev_strings)
        self.rows.append(row)

    def end_process(self):
        """Print the accumulated table and close the log file if needed."""
        click.echo(
            tabulate(self.rows, self.headers, self._tablefmt), file=self.log_file
        )
        super(MultiMetrics, self).end_process()


511
class PlotBase(MeasureBase):
    """Base class for plots.

    Gathers the options read from the click context and the code shared by
    the different plots: figure creation, titles, labels, legend, grid, axis
    limits and saving the figures to a PDF.
    """

    def __init__(self, ctx, scores, evaluation, func_load):
        super(PlotBase, self).__init__(ctx, scores, evaluation, func_load)
        meta = ctx.meta
        self._output = meta.get("output")
        self._points = meta.get("points", 2000)
        self._split = meta.get("split")
        self._axlim = meta.get("axlim")
        self._alpha = meta.get("alpha")
        self._disp_legend = meta.get("disp_legend", True)
        self._legend_loc = meta.get("legend_loc")

        # power of ten of the smallest x value: taken from `min_far_value`
        # when present, else from the lower x limit (a limit of 0 gives 0)
        self._min_dig = None
        if "min_far_value" in meta:
            self._min_dig = int(math.log10(meta["min_far_value"]))
        elif self._axlim is not None and self._axlim[0] is not None:
            x_min = self._axlim[0]
            self._min_dig = int(math.log10(x_min) if x_min != 0 else 0)

        self._clayout = meta.get("clayout")
        self._far_at = meta.get("lines_at")
        # x positions of the vertical lines; subclasses may transform them
        self._trans_far_val = self._far_at
        if self._far_at is not None:
            self._eval_points = {line: [] for line in self._far_at}
            self._lines_val = []
        self._print_fn = meta.get("show_fn", True)
        self._x_rotation = meta.get("x_rotation")
        if "style" in meta:
            mpl.style.use(meta["style"])
        # a second figure is only needed when eval curves are split from dev
        self._nb_figs = 2 if self._eval and self._split else 1
        self._colors = utils.get_colors(self.n_systems)
        self._line_linestyles = meta.get("line_styles", False)
        self._linestyles = utils.get_linestyles(self.n_systems, self._line_linestyles)
        self._titles = meta.get("titles", []) * 2
        # for compatibility
        self._title = meta.get("title")
        if not self._titles and self._title is not None:
            self._titles = [self._title] * 2

        self._x_label = meta.get("x_label")
        self._y_label = meta.get("y_label")
        self._grid_color = "silver"
        self._pdf_page = None
        self._end_setup_plot = True

    def init_process(self):
        """Open pdf and create the (cleared) figures."""
        if not hasattr(matplotlib, "backends"):
            matplotlib.use("pdf")

        meta = self._ctx.meta
        if "PdfPages" in meta:
            self._pdf_page = meta["PdfPages"]
        else:
            self._pdf_page = PdfPages(self._output)

        figsize = meta.get("figsize")
        for fig_number in range(1, self._nb_figs + 1):
            fig = mpl.figure(fig_number, figsize=figsize)
            fig.set_constrained_layout(self._clayout)
            fig.clear()

    def end_process(self):
        """Set title, legend, axis labels, grid colors, save figures, draw
        lines and close pdf if needed"""
        # draw vertical lines
        if self._far_at is not None:
            for line, line_trans in zip(self._far_at, self._trans_far_val):
                mpl.figure(1)
                mpl.plot([line_trans, line_trans], [-100.0, 100.0], "--", color="black")
                if self._eval and self._split:
                    mpl.figure(2)
                    # join the eval points of this line, ordered by x
                    ordered = sorted(self._eval_points[line], key=lambda pt: pt[0])
                    x_values = [pt[0] for pt in ordered]
                    y_values = [pt[1] for pt in ordered]
                    mpl.plot(x_values, y_values, "--", color="black")

        # only for plots
        if self._end_setup_plot:
            for fig_index in range(self._nb_figs):
                fig = mpl.figure(fig_index + 1)
                title = self._titles[fig_index] if self._titles else ""
                # a title made only of spaces is treated as no title
                mpl.title(title if title.replace(" ", "") else "")
                mpl.xlabel(self._x_label)
                mpl.ylabel(self._y_label)
                mpl.grid(True, color=self._grid_color)
                if self._disp_legend:
                    mpl.legend(loc=self._legend_loc)
                self._set_axis()
                mpl.xticks(rotation=self._x_rotation)
                self._pdf_page.savefig(fig)

        # do not want to close PDF when running evaluate
        # NOTE(review): a PdfPages created in init_process (no "PdfPages"
        # entry in the context) is not closed here -- confirm this is wanted.
        meta = self._ctx.meta
        if "PdfPages" in meta and meta.get("closef", True):
            self._pdf_page.close()

    # common protected functions

    def _label(self, base, idx):
        """Return the legend label of system ``idx``: the user legend when
        there is one, else ``base`` (numbered when there are several
        systems)."""
        legends = self._legends
        if legends is not None and idx < len(legends):
            return legends[idx]
        if self.n_systems > 1:
            return base + (" %d" % (idx + 1))
        return base

    def _set_axis(self):
        """Apply the axis limits to the current axes, if any were set."""
        if self._axlim is None:
            return
        mpl.axis(self._axlim)
624

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
625

626
class Roc(PlotBase):
    """Handles the plotting of ROC"""

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Roc, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ["ROC dev.", "ROC eval."]
        self._x_label = self._x_label or "FPR"
        self._semilogx = ctx.meta.get("semilogx", True)
        self._tpr = ctx.meta.get("tpr", True)
        if not self._y_label:
            self._y_label = "TPR" if self._tpr else "FNR"
        if not self._legend_loc:
            self._legend_loc = "lower right" if self._semilogx else "upper right"
        # custom defaults
        if self._axlim is None:
            self._axlim = [None, None, -0.05, 1.05]
        if self._min_dig is None:
            self._min_dig = -4

    def compute(self, idx, input_scores, input_names):
        """Plot ROC for dev and eval data using
        :py:func:`bob.measure.plot.roc`

        The dev curve is always drawn on figure 1. With eval data, the eval
        curve goes on figure 2 when ``split`` is set, else on figure 1 with a
        dashed line; the eval points at the ``lines_at`` operating points are
        scattered on the same figure and stored in ``_eval_points``.
        """
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        # options shared by the dev and eval curves
        shared = dict(
            npoints=self._points,
            semilogx=self._semilogx,
            tpr=self._tpr,
            min_far=self._min_dig,
            color=self._colors[idx],
            alpha=self._alpha,
        )

        mpl.figure(1)
        LOGGER.info("ROC dev. curve using %s", dev_file)
        plot.roc(
            dev_neg,
            dev_pos,
            linestyle=self._linestyles[idx],
            label=self._label("dev", idx),
            **shared
        )
        if not self._eval:
            return

        if self._split:
            mpl.figure(2)
            eval_linestyle = self._linestyles[idx]
        else:
            # same figure as dev: tell the eval curve apart with dashes
            eval_linestyle = "--"

        LOGGER.info("ROC eval. curve using %s", eval_file)
        plot.roc(
            eval_neg,
            eval_pos,
            linestyle=eval_linestyle,
            label=self._label("eval.", idx),
            **shared
        )

        if self._far_at is None:
            return

        from .. import fprfnr

        for line in self._far_at:
            # threshold chosen on dev, applied to eval
            thres_line = far_threshold(dev_neg, dev_pos, line)
            eval_fmr, eval_fnmr = fprfnr(eval_neg, eval_pos, thres_line)
            if self._tpr:
                eval_fnmr = 1 - eval_fnmr
            mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
            self._eval_points[line].append((eval_fmr, eval_fnmr))

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
711

712
class Det(PlotBase):
    """Handles the plotting of DET"""

    def __init__(self, ctx, scores, evaluation, func_load):
        super(Det, self).__init__(ctx, scores, evaluation, func_load)
        self._titles = self._titles or ["DET dev.", "DET eval."]
        self._x_label = self._x_label or "FPR (%)"
        self._y_label = self._y_label or "FNR (%)"
        self._legend_loc = self._legend_loc or "upper right"
        if self._far_at is not None:
            # vertical lines are drawn in the transformed (ppndf) axis
            self._trans_far_val = [ppndf(float(k)) for k in self._far_at]
        # custom defaults here
        if self._x_rotation is None:
            self._x_rotation = 50

        # Work on a private list: `_axlim` may be the very object stored in
        # the click context, which must not be modified in place (and may
        # not be a list at all).
        if self._axlim is None:
            axlim = [0.01, 99, 0.01, 99]
        else:
            axlim = list(self._axlim)

        if self._min_dig is not None:
            # lower x limit, in percent
            axlim[0] = math.pow(10, self._min_dig) * 100
        self._axlim = axlim

        self._min_dig = -4 if self._min_dig is None else self._min_dig

    def compute(self, idx, input_scores, input_names):
        """Plot DET for dev and eval data using
        :py:func:`bob.measure.plot.det`

        The dev curve is always drawn on figure 1. With eval data, the eval
        curve goes on figure 2 when ``split`` is set, else on figure 1 with a
        dashed line; the eval points at the ``lines_at`` operating points are
        scattered on the same figure and stored in ``_eval_points``.
        """
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        dev_file = input_names[0]
        if self._eval:
            eval_neg, eval_pos = neg_list[1], pos_list[1]
            eval_file = input_names[1]

        mpl.figure(1)
        LOGGER.info("DET dev. curve using %s", dev_file)
        plot.det(
            dev_neg,
            dev_pos,
            self._points,
            min_far=self._min_dig,
            color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label("dev.", idx),
            alpha=self._alpha,
        )

        # nothing more to draw without usable eval scores
        if not (self._eval and eval_neg is not None):
            return

        if self._split:
            mpl.figure(2)
            linestyle = self._linestyles[idx]
        else:
            # same figure as dev: tell the eval curve apart with dashes
            linestyle = "--"

        LOGGER.info("DET eval. curve using %s", eval_file)
        plot.det(
            eval_neg,
            eval_pos,
            self._points,
            min_far=self._min_dig,
            color=self._colors[idx],
            linestyle=linestyle,
            label=self._label("eval.", idx),
            alpha=self._alpha,
        )

        if self._far_at is None:
            return

        from .. import farfrr

        for line in self._far_at:
            # threshold chosen on dev, applied to eval
            thres_line = far_threshold(dev_neg, dev_pos, line)
            eval_fmr, eval_fnmr = farfrr(eval_neg, eval_pos, thres_line)
            eval_fmr, eval_fnmr = ppndf(eval_fmr), ppndf(eval_fnmr)
            mpl.scatter(eval_fmr, eval_fnmr, c=self._colors[idx], s=30)
            self._eval_points[line].append((eval_fmr, eval_fnmr))

    def _set_axis(self):
        """Apply the axis limits with ``plot.det_axis``."""
        plot.det_axis(self._axlim)
796

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
797

798
class Epc(PlotBase):
    """Handles the plotting of EPC.

    An EPC curve is always built from a development set and an evaluation
    set, so this class requires two score files per system and forces the
    evaluation mode on.
    """

    def __init__(self, ctx, scores, evaluation, func_load, hter="HTER"):
        """
        Parameters
        ----------
        ctx : :py:class:`dict`
            Click context dictionary.
        scores : :any:`list`
            List of input files, forwarded to the parent constructor.
        evaluation : :py:class:`bool`
            Forwarded to the parent constructor; ``self._eval`` is forced to
            ``True`` afterwards regardless of this value.
        func_load :
            Function that is used to load the input files.
        hter : :py:class:`str`
            Name of the error measure, used to build the default y-axis
            label (``hter + " (%)"``).

        Raises
        ------
        click.UsageError
            If ``self._min_arg`` (set by the parent constructor) is not 2.
        """
        super(Epc, self).__init__(ctx, scores, evaluation, func_load)
        if self._min_arg != 2:
            raise click.UsageError("EPC requires dev. and eval. score files")
        # Only fall back to the EPC defaults when the user gave no value.
        self._titles = self._titles or ["EPC"] * 2
        self._x_label = self._x_label or r"$\alpha$"
        self._y_label = self._y_label or hter + " (%)"
        self._legend_loc = self._legend_loc or "upper center"
        self._eval = True  # always eval data with EPC
        self._split = False
        self._nb_figs = 1
        self._far_at = None

    def compute(self, idx, input_scores, input_names):
        """Plot EPC using :py:func:`bob.measure.plot.epc`

        Parameters
        ----------
        idx : :py:class:`int`
            Index of the current system; selects the color, linestyle and
            label of the curve.
        input_scores :
            Loaded scores of the current system, passed to
            :py:func:`utils.get_fta_list`. The first entry of the returned
            lists is used as the dev. set and the second as the eval. set.
        input_names :
            Names of the input files; the first two entries (dev. and eval.)
            are only used for logging.
        """
        neg_list, pos_list, _ = utils.get_fta_list(input_scores)
        # Both sets are mandatory for an EPC (enforced in __init__), so bind
        # them unconditionally instead of under `if self._eval`, which would
        # leave the names undefined if `_eval` were ever reset.
        dev_neg, dev_pos = neg_list[0], pos_list[0]
        eval_neg, eval_pos = neg_list[1], pos_list[1]
        dev_file, eval_file = input_names[0], input_names[1]

        LOGGER.info("EPC using %s_%s", dev_file, eval_file)
        plot.epc(
            dev_neg,
            dev_pos,
            eval_neg,
            eval_pos,
            self._points,
            color=self._colors[idx],
            linestyle=self._linestyles[idx],
            label=self._label("curve", idx),
            alpha=self._alpha,
        )

Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
836

837 838 839 840 841 842 843 844 845 846 847 848 849 850
class GridSubplot(PlotBase):
    """Base class for plots made of several subplots and a shared legend.

    It exists because "constrained_layout will not work on subplots created
    via the subplot command."

    Subclasses should obtain every new axis through `create_subplot` inside
    `compute`, and call `finalize_one_page` from `compute` once a page has
    been completely rendered.
    """

    def __init__(self, ctx, scores, evaluation, func_load):
        super(GridSubplot, self).__init__(ctx, scores, evaluation, func_load)

        # Legend location: unset and "best" both fall back to "upper center"
        legend_loc = self._legend_loc or "upper center"
        if legend_loc == "best":
            legend_loc = "upper center"
        self._legend_loc = legend_loc
        if not ("upper" in legend_loc or "lower" in legend_loc):
            message = (
                "Only best, (upper *), and (lower-*) legend locations are supported!"
            )
            raise ValueError(message)

        # The legend takes the first grid row when placed on top (subplots
        # are then shifted down by one row), otherwise the last grid row.
        legend_on_top = "up" in legend_loc
        self._legend_grid_axis_number = 0 if legend_on_top else -1
        self._grid_axis_offset = 1 if legend_on_top else 0

        # Number of subplot rows and columns on one page
        self._nrows = ctx.meta.get("n_row", 1)
        self._ncols = ctx.meta.get("n_col", 1)

        # Number of GridSpec rows / columns spanned by a single subplot
        self._row_times = 8
        self._col_times = 2

    def init_process(self):
        """Run the parent initialisation, then build the page GridSpec."""
        super(GridSubplot, self).init_process()
        self._create_grid_spec()

    def _create_grid_spec(self):
        """Create a GridSpec on the current figure and store it in
        ``self._gs``; one extra row is reserved on top of the subplot rows."""
        grid_rows = self._nrows * self._row_times + 1
        grid_cols = self._ncols * self._col_times
        self._gs = gridspec.GridSpec(grid_rows, grid_cols, figure=mpl.gcf())

    def create_subplot(self, n, shared_axis=None):
        """Add the ``n``-th subplot of the page to the current figure.

        ``n`` is converted to a (row, column) position in the
        ``(self._nrows, self._ncols)`` layout. ``shared_axis`` is passed as
        ``sharex`` to ``add_subplot``. Returns the new axis.
        """
        row, col = numpy.unravel_index(n, (self._nrows, self._ncols))
        top = row * self._row_times + self._grid_axis_offset
        bottom = (row + 1) * self._row_times + self._grid_axis_offset
        left = col * self._col_times
        right = (col + 1) * self._col_times
        return mpl.gcf().add_subplot(
            self._gs[top:bottom, left:right], sharex=shared_axis
        )

    def finalize_one_page(self):
        """Draw the legend, save the current page to the PDF, then start a
        new figure with a fresh GridSpec."""
        self.plot_legends()
        self._pdf_page.savefig(bbox_inches="tight")
        mpl.clf()
        mpl.figure()
        self._create_grid_spec()

    def plot_legends(self):
        """Print legend on current page"""
        handles = []
        names = []
        for axis in mpl.gcf().get_axes():
            axis_handles, axis_names = axis.get_legend_handles_labels()
            for handle, name in zip(axis_handles, axis_names):
                # keep only the first handle seen for each label
                if name in names:
                    continue
                handles.append(handle)
                names.append(name)

        if not self._disp_legend:
            return

        # the legend lives in its own axis on the top or bottom grid row
        legend_axis = mpl.gcf().add_subplot(
            self._gs[self._legend_grid_axis_number, :]
        )
        # second word of the location: right, left, or center
        horizontal = self._legend_loc.split()[1]
        legend_axis.legend(handles, names, loc=horizontal, ncol=self._nlegends)
        # only the legend is shown, not the axis that carries it
        legend_axis.set_axis_off()


class Hist(GridSubplot):
924
    """ Functional base class for histograms"""
Amir MOHAMMADI's avatar
lint  
Amir MOHAMMADI committed
925

926
    def __init__(self, ctx, scores, evaluation, func_load, nhist_per_system=2):
        """
        Parameters
        ----------
        ctx : :py:class:`dict`
            Click context dictionary. The following ``ctx.meta`` keys are
            read here: ``n_bins``, ``thres``, ``criterion``, ``no_line``,
            ``hide_dev``, ``legends_ncol`` and ``titles``.
        scores : :any:`list`
            List of input files, forwarded to the parent constructor.
        evaluation : :py:class:`bool`
            Forwarded to the parent constructor.
        func_load :
            Function that is used to load the input files.
        nhist_per_system : :py:class:`int`
            Number of histograms drawn per system; ``n_bins`` must hold
            either one value or this many values.

        Raises
        ------
        click.BadParameter
            If ``n_bins`` or ``thres`` have an unexpected number of values,
            or if ``hide_dev`` is set without evaluation data.
        """
        super(Hist, self).__init__(ctx, scores, evaluation, func_load)
        self._nbins = ctx.meta.get("n_bins", ["doane"])
        self._nhist_per_system = nhist_per_system
        self._nbins = check_list_value(
            self._nbins, nhist_per_system, "n_bins", "histograms"
        )
        self._thres = ctx.meta.get("thres")
        self._thres = check_list_value(self._thres, self.n_systems, "thresholds")
        self._criterion = ctx.meta.get("criterion")
        # no vertical (threshold) is displayed
        self._no_line = ctx.meta.get("no_line", False)
        # do not display dev histo
        self._hide_dev = ctx.meta.get("hide_dev", False)
        if self._hide_dev and not self._eval:
            raise click.BadParameter("You can only use --hide-dev along with --eval")
        # rows are doubled when both dev and eval histograms are displayed
        self._nrows *= 1 if self._hide_dev or not self._eval else 2
        self._nlegends = ctx.meta.get("legends_ncol", 3)

        # number of subplot on one page
        self._step_print = int(self._nrows * self._ncols)
        self._title_base = "Scores"
        self._y_label = self._y_label or "Probability density"
        self._x_label = self._x_label or "Score values"
        self._end_setup_plot = False
        # override _titles of PlotBase: the list of titles is repeated twice.
        # A "titles" entry that is present but None is treated as "no titles"
        # instead of failing on `None * 2`.
        titles = ctx.meta.get("titles", [])
        self._titles = [] if titles is None else titles * 2
954

955
    def compute(self, idx, input_scores, input_names):
956 957 958 959
        """ Draw histograms of negative and positive scores."""
        dev_neg, dev_pos, eval_neg, eval_pos, threshold = self._get_neg_pos_thres(
            idx, input_scores, input_names
        )
960

961 962
        # keep id of the current system
        sys = idx
Amir MOHAMMADI's avatar
Amir MOHAMMADI committed
963
        # if the id of the current system does not match the id of the plot,
964 965 966 967 968
        # change it
        if not self._hide_dev and self._eval:
            row = int(idx / self._ncols) * 2
            col = idx % self._ncols
            idx = col + self._ncols * row
969

970 971
        dev_axis = None

972
        if not self._hide_dev or not self._eval:
973 974
            dev_axis = self._print_subplot(
                idx, sys, dev_neg, dev_pos, threshold, not self._no_line, False,
975
            )