# cross.py
"""Prints Cross-db metrics analysis
"""
import json
import logging
import math
import os

import click
import jinja2
import yaml
from bob.bio.base.score.load import split
from bob.extension.scripts.click_helper import (
    verbosity_option, bool_option, log_parameters)
from bob.measure import eer_threshold, farfrr
from bob.measure.script import common_options
from bob.measure.utils import get_fta
from gridtk.generator import expand
from tabulate import tabulate

# Logger for this module, named after the module itself (``__name__``).
logger = logging.getLogger(__name__)


@click.command(epilog='''\b
Examples:
  $ bin/bob pad cross 'results/{{ evaluation.database }}/{{ algorithm }}/{{ evaluation.protocol }}/scores/scores-{{ group }}' \\
    -td replaymobile -d replaymobile -p grandtest -d oulunpu -p Protocol_1 \\
    -a replaymobile_frame-diff-svm \\
    -a replaymobile_qm-svm-64 \\
    -a replaymobile_lbp-svm-64 \\
    > replaymobile.rst &
''')
@click.argument('score_jinja_template')
@click.option('-d', '--database', 'databases', multiple=True, required=True,
              show_default=True,
              help='Names of the evaluation databases')
@click.option('-p', '--protocol', 'protocols', multiple=True, required=True,
              show_default=True,
              help='Names of the protocols of the evaluation databases')
@click.option('-a', '--algorithm', 'algorithms', multiple=True, required=True,
              show_default=True,
              help='Names of the algorithms')
@click.option('-n', '--names', type=click.File('r'),
              help='Name of algorithms to show in the table. Provide a path '
              'to a json file maps algorithm names to names that you want to '
              'see in the table.')
@click.option('-td', '--train-database', required=True,
              help='The database that was used to train the algorithms.')
@click.option('-g', '--group', 'groups', multiple=True, show_default=True,
              default=['train', 'dev', 'eval'],
              help='Score groups to evaluate. The eval group is thresholded '
              'at the EER threshold found on dev.')
@bool_option('sort', 's', 'whether the table should be sorted.', True)
@common_options.table_option()
@common_options.output_log_metric_option()
@verbosity_option()
@click.pass_context
def cross(ctx, score_jinja_template, databases, protocols, algorithms,
          names, train_database, groups, sort, **kwargs):
    """Cross-db analysis metrics

    For every combination of evaluation database/protocol, algorithm and
    group, SCORE_JINJA_TEMPLATE is rendered into a score file path and the
    scores are evaluated. Non-eval groups report the HTER at their own EER
    threshold; the eval group reports APCER, BPCER and ACER at the threshold
    selected on dev. Missing score files are reported as nan.
    """
    log_parameters(logger)

    # -d and -p values are paired positionally; zip() would silently drop
    # the unmatched tail, so reject mismatched counts up front.
    if len(databases) != len(protocols):
        raise click.BadParameter(
            'got {} --database and {} --protocol values; they are paired '
            'one-to-one and must have the same count.'.format(
                len(databases), len(protocols)),
            param_hint='--database/--protocol')

    names = {} if names is None else json.load(names)

    # Protocol the training database was evaluated with (None when the
    # training database is not an evaluation database). It is needed both
    # for sorting and for shortening algorithm names, so it must be computed
    # regardless of ``sort``.
    train_protocol = None
    if train_database in databases:
        train_protocol = protocols[databases.index(train_database)]
    elif sort:
        logger.warning(
            'Train database %s is not among the evaluation databases; '
            'the table will not be sorted.', train_database)
        sort = False

    # Compile the path template once instead of once per combination.
    template = jinja2.Environment(
        undefined=jinja2.StrictUndefined).from_string(score_jinja_template)

    data = {
        'evaluation': [{'database': db, 'protocol': proto}
                       for db, proto in zip(databases, protocols)],
        'algorithm': algorithms,
        'group': groups,
    }

    # Resolve one score file path per (database, protocol, algorithm, group).
    score_paths = {}
    for variables in expand(yaml.dump(data, Dumper=yaml.SafeDumper)):
        logger.debug(variables)

        score_path = template.render(variables)
        logger.debug(score_path)

        database = variables['evaluation']['database']
        protocol = variables['evaluation']['protocol']
        algorithm = variables['algorithm']
        group = variables['group']

        # If the algorithm name does not contain the training database name
        # and we evaluate on another database, the path uses
        # "<database>_<algorithm>" in place of the plain algorithm name.
        if train_database not in algorithm and database != train_database:
            score_path = score_path.replace(
                algorithm, database + '_' + algorithm)

        score_paths[(database, protocol, algorithm, group)] = score_path

    nan = float('nan')
    metrics = {}
    # Each metrics value is (hter, threshold, fta, far, frr). Eval entries
    # reuse the dev threshold, so they are processed after every other group
    # regardless of the order in which groups were expanded or given.
    for key in sorted(score_paths, key=lambda k: k[3] == 'eval'):
        database, protocol, algorithm, group = key
        score_path = score_paths[key]

        if not os.path.exists(score_path):
            metrics[key] = (nan, ) * 5
            continue

        (neg, pos), fta = get_fta(split(score_path))

        if group == 'eval':
            dev_metrics = metrics.get((database, protocol, algorithm, 'dev'))
            threshold = nan if dev_metrics is None else dev_metrics[1]
            if math.isnan(threshold):
                # No usable dev threshold (dev not requested or its score
                # file is missing): a nan threshold would not give meaningful
                # error rates, so report nan instead.
                logger.warning(
                    'No dev threshold available for %s; reporting nan.',
                    score_path)
                metrics[key] = (nan, nan, fta, nan, nan)
                continue
        else:
            try:
                threshold = eer_threshold(neg, pos)
            except RuntimeError:
                logger.error("Something wrong with {}".format(score_path))
                raise

        far, frr = farfrr(neg, pos, threshold)
        hter = (far + frr) / 2
        metrics[key] = (hter, threshold, fta, far, frr)

    logger.debug('metrics: %s', metrics)

    # Column headers follow the requested groups: eval contributes APCER,
    # BPCER and ACER; every other group contributes one "EER_<initial>"
    # column. The database name prefixes the first column of its section.
    headers = ["Algorithms"]
    for db in databases:
        columns = []
        for group in groups:
            if group == 'eval':
                columns += ["APCER", "BPCER", "ACER"]
            else:
                columns.append("EER_" + group[0])
        if columns:
            headers += [db + "\n" + columns[0]]
            headers += ["\n" + column for column in columns[1:]]

    # Sort the algorithms by their HTER on the training database, using
    # eval, then dev, then train (only the groups that were computed).
    # Missing values (nan) count as 1, the worst possible HTER.
    if sort:
        def sort_key(alg):
            key = []
            for grp in ('eval', 'dev', 'train'):
                if grp not in groups:
                    continue
                hter = metrics[(train_database, train_protocol, alg, grp)][0]
                key.append(1 if math.isnan(hter) else hter)
            return tuple(key)
        algorithms = sorted(algorithms, key=sort_key)

    rows = []
    for algorithm in algorithms:
        # Strip the training database/protocol prefixes, then apply any
        # user-provided display name.
        name = algorithm.replace(train_database + '_', '')
        if train_protocol is not None:
            name = name.replace(train_protocol + '_', '')
        name = names.get(name, name)
        row = [name]
        for database, protocol in zip(databases, protocols):
            for group in groups:
                hter, _, _, far, frr = metrics[(
                    database, protocol, algorithm, group)]
                cell = [far, frr, hter] if group == 'eval' else [hter]
                # Rates are shown as percentages with one decimal.
                row.extend(round(c * 100, 1) for c in cell)
        rows.append(row)

    title = ' Trained on {} '.format(train_database)
    title_line = '\n' + '=' * len(title) + '\n'
    click.echo(title_line + title + title_line, file=ctx.meta['log'])
    click.echo(tabulate(rows, headers, ctx.meta['tablefmt'], floatfmt=".1f"),
               file=ctx.meta['log'])