"""Prints Cross-db metrics analysis
"""
import click
import jinja2
import logging
import math
import os
import yaml
from bob.bio.base.score.load import split
from bob.extension.scripts.click_helper import (
    verbosity_option, bool_option, log_parameters)
from bob.measure import eer_threshold, farfrr
from bob.measure.script import common_options
from bob.measure.utils import get_fta
from gridtk.generator import expand
from tabulate import tabulate

logger = logging.getLogger(__name__)


@click.command(epilog='''\b
Examples:

  $ bin/bob pad cross \
    'results/{{ evaluation.database }}/{{ algorithm }}/{{ evaluation.protocol }}/scores/scores-{{ group }}' \
    -td replaymobile \
    -d replaymobile -p grandtest \
    -d oulunpu -p Protocol_1 \
    -a replaymobile_frame-diff-svm \
    -a replaymobile_qm-svm-64 \
    -a replaymobile_lbp-svm-64 \
    > replaymobile.rst &
''')
@click.argument('score_jinja_template')
@click.option('-d', '--database', 'databases', multiple=True, required=True,
              show_default=True,
              help='Names of the evaluation databases')
@click.option('-p', '--protocol', 'protocols', multiple=True, required=True,
              show_default=True,
              help='Names of the protocols of the evaluation databases')
@click.option('-a', '--algorithm', 'algorithms', multiple=True, required=True,
              show_default=True,
              help='Names of the algorithms')
@click.option('-td', '--train-database', required=True,
              help='The database that was used to train the algorithms.')
@click.option('-g', '--group', 'groups', multiple=True, show_default=True,
              default=['train', 'dev', 'eval'])
@bool_option('sort', 's', 'Whether the table should be sorted.', True)
@common_options.table_option()
@common_options.output_log_metric_option()
@verbosity_option()
@click.pass_context
def cross(ctx, score_jinja_template, databases, protocols, algorithms,
          train_database, groups, sort, **kwargs):
    """Cross-db analysis metrics
    """
    log_parameters(logger)

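    # StrictUndefined makes rendering fail early if the score template
    # references a variable that is not defined in ``data`` below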
    env = jinja2.Environment(undefined=jinja2.StrictUndefined)

    data = {
        'evaluation': [{'database': db, 'protocol': proto}
                       for db, proto in zip(databases, protocols)],
        'algorithm': algorithms,
        'group': groups,
    }

    metrics = {}

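    # gridtk's expand yields one variable assignment per combination of the
    # values above; each 'evaluation' entry keeps its database and protocol
    # paired together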
    for variables in expand(yaml.dump(data, Dumper=yaml.SafeDumper)):
        logger.debug(variables)

        score_path = env.from_string(score_jinja_template).render(variables)
        logger.debug(score_path)

        database, protocol, algorithm, group = \
            variables['evaluation']['database'], \
            variables['evaluation']['protocol'], \
            variables['algorithm'], variables['group']

        # If the algorithm name does not embed the training database name and
        # this is a cross-database evaluation, the scores are stored under an
        # algorithm name prefixed with the evaluation database.
        if train_database not in algorithm and database != train_database:
            score_path = score_path.replace(
                algorithm, database + '_' + algorithm)

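        # missing score files yield a row of NaNs instead of aborting the run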
        if not os.path.exists(score_path):
            metrics[(database, protocol, algorithm, group)] = \
                (float('nan'), ) * 5
            continue

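        # drop invalid (NaN) scores and compute the failure-to-acquire rate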
        (neg, pos), fta = get_fta(split(score_path))

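        # the eval group is scored at the threshold computed on the dev group
        # at the EER point; this relies on 'dev' being expanded before 'eval'
        # (the default --group order)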
        if group == 'eval':
            threshold = metrics[(database, protocol, algorithm, 'dev')][1]
        else:
            threshold = eer_threshold(neg, pos)

        far, frr = farfrr(neg, pos, threshold)
        hter = (far + frr) / 2

        metrics[(database, protocol, algorithm, group)] = \
            (hter, threshold, fta, far, frr)

    logger.debug('metrics: %s', metrics)

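    # one column group of five per database: EER on the train and dev groups,
    # then APCER, BPCER, and ACER (the eval-group FAR, FRR, and HTER at the
    # dev-group threshold)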
    headers = ["Algorithms"]
    for db in databases:
        headers += [db + "\nEER_t", "\nEER_d", "\nAPCER", "\nBPCER", "\nACER"]
    rows = []

    # the training protocol is needed both for sorting and for shortening the
    # algorithm names below
    train_protocol = protocols[databases.index(train_database)]

    # sort the algorithms based on eval HTER, dev EER, and train EER
    if sort:

        def sort_key(alg):
            r = []
            for grp in ('eval', 'dev', 'train'):
                hter = metrics[(train_database, train_protocol, alg, grp)][0]
                # missing score files produced NaN; rank those last
                r.append(1 if math.isnan(hter) else hter)
            return tuple(r)

        algorithms = sorted(algorithms, key=sort_key)

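    # one row per algorithm; metric values are reported as percentages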
    for algorithm in algorithms:
        rows.append([algorithm.replace(train_database + '_', '')
                     .replace(train_protocol + '_', '')])
        for database, protocol in zip(databases, protocols):
            cell = []
            for group in groups:
                hter, threshold, fta, far, frr = metrics[(
                    database, protocol, algorithm, group)]
                if group == 'eval':
                    cell += [far, frr, hter]
                else:
                    cell += [hter]
            cell = [round(c * 100, 1) for c in cell]
            rows[-1].extend(cell)

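    # emit a reStructuredText-style section title followed by the table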
    title = ' Trained on {} '.format(train_database)
    title_line = '\n' + '=' * len(title) + '\n'
    click.echo(title_line + title + title_line, file=ctx.meta['log'])
    click.echo(tabulate(rows, headers, ctx.meta['tablefmt'], floatfmt=".1f"),
               file=ctx.meta['log'])