"""Prints Cross-db metrics analysis
"""
import click
import jinja2
import logging
import math
import os
import yaml
from bob.bio.base.score.load import split
from bob.extension.scripts.click_helper import (
    verbosity_option, bool_option, log_parameters)
from bob.measure import eer_threshold, farfrr
from bob.measure.script import common_options
from bob.measure.utils import get_fta
from gridtk.generator import expand
from tabulate import tabulate

logger = logging.getLogger(__name__)


@click.command(epilog='''\b
Examples:

  $ bin/bob pad cross 'results/{{ evaluation.database }}/{{ algorithm }}/{{ evaluation.protocol }}/scores/scores-{{ group }}' -td replaymobile -d replaymobile -p grandtest -d oulunpu -p Protocol_1 \
    -a replaymobile_frame-diff-svm \
    -a replaymobile_qm-svm-64 \
    -a replaymobile_lbp-svm-64 \
    > replaymobile.rst &
''')
@click.argument('score_jinja_template')
@click.option('-d', '--database', 'databases', multiple=True, required=True,
              show_default=True,
              help='Names of the evaluation databases')
@click.option('-p', '--protocol', 'protocols', multiple=True, required=True,
              show_default=True,
              help='Names of the protocols of the evaluation databases')
@click.option('-a', '--algorithm', 'algorithms', multiple=True, required=True,
              show_default=True,
              help='Names of the algorithms')
@click.option('-td', '--train-database', required=True,
              help='The database that was used to train the algorithms.')
@click.option('-g', '--group', 'groups', multiple=True, show_default=True,
              default=['train', 'dev', 'eval'])
@bool_option('sort', 's', 'Whether the table should be sorted.', True)
@common_options.table_option()
@common_options.output_log_metric_option()
@verbosity_option()
@click.pass_context
def cross(ctx, score_jinja_template, databases, protocols, algorithms,
          train_database, groups, sort, **kwargs):
    """Cross-db analysis metrics
    """
    log_parameters(logger)

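    # StrictUndefined makes jinja2 raise an error for template variables that
    # are missing from the rendering context, instead of silently producing
    # an empty string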
    env = jinja2.Environment(undefined=jinja2.StrictUndefined)

    data = {
        'evaluation': [{'database': db, 'protocol': proto}
                       for db, proto in zip(databases, protocols)],
        'algorithm': algorithms,
        'group': groups,
    }

    metrics = {}

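    # `expand` yields one variable assignment per element of the cross
    # product of the list values in `data`, e.g. (illustrative):
    #   {'evaluation': {'database': 'replaymobile', 'protocol': 'grandtest'},
    #    'algorithm': 'replaymobile_qm-svm-64', 'group': 'dev'}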
    for variables in expand(yaml.dump(data, Dumper=yaml.SafeDumper)):
        logger.debug(variables)

        score_path = env.from_string(score_jinja_template).render(variables)
        logger.debug(score_path)

        database, protocol, algorithm, group = \
            variables['evaluation']['database'], \
            variables['evaluation']['protocol'], \
            variables['algorithm'], variables['group']

        # if the algorithm name does not embed the train database name and we
        # are evaluating on a different database, the score files are stored
        # under a '<database>_' prefixed algorithm name
        if train_database not in algorithm and database != train_database:
            score_path = score_path.replace(
                algorithm, database + '_' + algorithm)

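        # record NaNs for missing score files so every algorithm still gets a
        # complete row in the final table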
        if not os.path.exists(score_path):
            metrics[(database, protocol, algorithm, group)] = \
                (float('nan'), ) * 5
            continue

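        # split() loads the negative and positive scores from the file;
        # get_fta() removes NaN scores and returns the failure-to-acquire rate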
        (neg, pos), fta = get_fta(split(score_path))

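        # train and dev use their own EER threshold; eval reuses the dev
        # threshold, so 'dev' must appear before 'eval' in the groups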
        if group == 'eval':
            threshold = metrics[(database, protocol, algorithm, 'dev')][1]
        else:
            threshold = eer_threshold(neg, pos)

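        # error rates at the chosen threshold; their mean is the half total
        # error rate (an EER on train/dev, the ACER on eval)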
        far, frr = farfrr(neg, pos, threshold)
        hter = (far + frr) / 2

        metrics[(database, protocol, algorithm, group)] = \
            (hter, threshold, fta, far, frr)

    logger.debug('metrics: %s', metrics)

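    # with the default train/dev/eval groups, five columns per database:
    # EER on the train and dev sets, then APCER, BPCER and ACER on the eval
    # set (all reported in percent)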
    headers = ["Algorithms"]
    for db in databases:
        headers += [db + "\nEER_t", "\nEER_d", "\nAPCER", "\nBPCER", "\nACER"]
    rows = []

    # sort the algorithms by their error rates on the train database:
    # eval HTER first, then dev EER, then train EER
    if sort:
        train_protocol = protocols[databases.index(train_database)]

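        # a NaN error rate (missing score file) is mapped to 1.0 so that
        # such algorithms sort last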
        def sort_key(alg):
            r = []
            for grp in ('eval', 'dev', 'train'):
                hter = metrics[(train_database, train_protocol, alg, grp)][0]
                r.append(1 if math.isnan(hter) else hter)
            return tuple(r)
        algorithms = sorted(algorithms, key=sort_key)

    for algorithm in algorithms:
        rows.append([algorithm.replace(train_database + '_', '')])
        for database, protocol in zip(databases, protocols):
            cell = []
            for group in groups:
                hter, threshold, fta, far, frr = metrics[(
                    database, protocol, algorithm, group)]
                if group == 'eval':
                    cell += [far, frr, hter]
                else:
                    cell += [hter]
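            # convert fractions to percentages rounded to one decimal place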
            cell = [round(c * 100, 1) for c in cell]
            rows[-1].extend(cell)

    title = ' Trained on {} '.format(train_database)
    title_line = '\n' + '=' * len(title) + '\n'
    click.echo(title_line + title + title_line, file=ctx.meta['log'])
    click.echo(tabulate(rows, headers, ctx.meta['tablefmt'], floatfmt=".1f"),
               file=ctx.meta['log'])