"""Prints Cross-db metrics analysis
"""
import click
import json
import jinja2
import logging
import math
import os
import yaml
from bob.bio.base.score.load import load_score, get_negatives_positives
from bob.extension.scripts.click_helper import (
    verbosity_option,
    bool_option,
    log_parameters,
)
from bob.measure import eer_threshold, farfrr
from bob.measure.script import common_options
from bob.measure.utils import get_fta
from gridtk.generator import expand
from tabulate import tabulate

logger = logging.getLogger(__name__)


@click.command(
    epilog="""\b
Examples:
  $ bin/bob pad cross 'results/{{ evaluation.database }}/{{ algorithm }}/{{ evaluation.protocol }}/scores/scores-{{ group }}' \
    -td replaymobile \
    -d replaymobile -p grandtest \
    -d oulunpu -p Protocol_1 \
    -a replaymobile_grandtest_frame-diff-svm \
    -a replaymobile_grandtest_qm-svm-64 \
    -a replaymobile_grandtest_lbp-svm-64 \
    > replaymobile.rst &
"""
)
@click.argument("score_jinja_template")
@click.option(
    "-d",
    "--database",
    "databases",
    multiple=True,
    required=True,
    show_default=True,
    help="Names of the evaluation databases",
)
@click.option(
    "-p",
    "--protocol",
    "protocols",
    multiple=True,
    required=True,
    show_default=True,
    help="Names of the protocols of the evaluation databases",
)
@click.option(
    "-a",
    "--algorithm",
    "algorithms",
    multiple=True,
    required=True,
    show_default=True,
    help="Names of the algorithms",
)
@click.option(
    "-n",
    "--names",
    type=click.File("r"),
    help="Name of algorithms to show in the table. Provide a path "
    "to a json file maps algorithm names to names that you want to "
    "see in the table.",
)
@click.option(
    "-td",
    "--train-database",
    required=True,
    help="The database that was used to train the algorithms.",
)
@click.option(
    "-pn",
    "--pai-names",
    type=click.File("r"),
    help="Name of PAIs to compute the errors per PAI. Provide a path "
    "to a json file maps attack_type in scores to PAIs that you want to "
    "see in the table.",
)
@click.option(
    "-g",
    "--group",
    "groups",
    multiple=True,
    show_default=True,
    default=["train", "dev", "eval"],
)
@bool_option("sort", "s", "whether the table should be sorted.", True)
@common_options.table_option()
@common_options.output_log_metric_option()
@verbosity_option()
@click.pass_context
def cross(
    ctx,
    score_jinja_template,
    databases,
    protocols,
    algorithms,
    names,
    train_database,
    pai_names,
    groups,
    sort,
    verbose,
    **kwargs
):
    """Cross-db analysis metrics
    """
    log_parameters(logger)

    names = {} if names is None else json.load(names)
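    # optional mapping from algorithm names to display names for the table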

    env = jinja2.Environment(undefined=jinja2.StrictUndefined)
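    # StrictUndefined makes rendering fail loudly when the template
    # references a variable that was not provided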

    data = {
        "evaluation": [
            {"database": db, "protocol": proto}
            for db, proto in zip(databases, protocols)
        ],
        "algorithm": algorithms,
        "group": groups,
    }

    metrics = {}
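    # keyed by (database, protocol, algorithm, group), each value is a
    # (hter, threshold, fta, far, frr) tuple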

    for variables in expand(yaml.dump(data, Dumper=yaml.SafeDumper)):
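        # ``expand`` yields one variable set per combination of evaluation
        # (database and protocol), algorithm, and group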
        logger.debug(variables)

        score_path = env.from_string(score_jinja_template).render(variables)
        logger.info(score_path)

        database, protocol, algorithm, group = (
            variables["evaluation"]["database"],
            variables["evaluation"]["protocol"],
            variables["algorithm"],
            variables["group"],
        )

        # when the algorithm name does not embed the training database name
        # and we are evaluating on a different database, the scores live
        # under a "<database>_<algorithm>" folder instead
        if train_database not in algorithm and database != train_database:
            score_path = score_path.replace(algorithm, database + "_" + algorithm)
            logger.info("Score path changed to: %s", score_path)

        if not os.path.exists(score_path):
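            # missing score files show up as NaN cells in the final table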
            metrics[(database, protocol, algorithm, group)] = (float("nan"),) * 5
            continue

        scores = load_score(score_path)
        neg, pos = get_negatives_positives(scores)
        (neg, pos), fta = get_fta((neg, pos))
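        # get_fta drops NaN scores and returns the failure-to-acquire rate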

        if group == "eval":
            threshold = metrics[(database, protocol, algorithm, "dev")][1]
        else:
            try:
                threshold = eer_threshold(neg, pos)
            except RuntimeError:
                logger.error("Failed to compute the EER threshold for %s", score_path)
                raise

        far, frr = farfrr(neg, pos, threshold)
        hter = (far + frr) / 2
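        # on the train and dev groups the threshold is the EER threshold,
        # so HTER ~= EER there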

        metrics[(database, protocol, algorithm, group)] = (
            hter,
            threshold,
            fta,
            far,
            frr,
        )

    logger.debug("metrics: %s", metrics)

    headers = ["Algorithms"]
    for db in databases:
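        # five columns per database (assuming the default train/dev/eval
        # groups): EER on train, EER on dev, and APCER/BPCER/ACER on eval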
        headers += [db + "\nEER_t", "\nEER_d", "\nAPCER", "\nBPCER", "\nACER"]
    rows = []

    # the protocol that matches the training database
    train_protocol = protocols[databases.index(train_database)]

    # sort the algorithms based on eval HTER, then dev EER, then train EER
    if sort:

        def sort_key(alg):
            r = []
            for grp in ("eval", "dev", "train"):
                hter = metrics[(train_database, train_protocol, alg, grp)][0]
                r.append(1 if math.isnan(hter) else hter)
            return tuple(r)

        algorithms = sorted(algorithms, key=sort_key)

    for algorithm in algorithms:
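        # strip the training database/protocol prefixes for a shorter name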
        name = algorithm.replace(train_database + "_", "")
        name = name.replace(train_protocol + "_", "")
        name = names.get(name, name)
        rows.append([name])
        for database, protocol in zip(databases, protocols):
            cell = []
            for group in groups:
                hter, threshold, fta, far, frr = metrics[
                    (database, protocol, algorithm, group)
                ]
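                # eval reports APCER/BPCER/ACER; train and dev only the EER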
                if group == "eval":
                    cell += [far, frr, hter]
                else:
                    cell += [hter]
            cell = [round(c * 100, 1) for c in cell]
            rows[-1].extend(cell)

    title = " Trained on {} ".format(train_database)
    title_line = "\n" + "=" * len(title) + "\n"
    # open log file for writing if any
    ctx.meta["log"] = (
        ctx.meta["log"] if ctx.meta["log"] is None else open(ctx.meta["log"], "w")
    )
    click.echo(title_line + title + title_line, file=ctx.meta["log"])
    click.echo(
        tabulate(rows, headers, ctx.meta["tablefmt"], floatfmt=".1f"),
        file=ctx.meta["log"],
    )