"""Prints Cross-db metrics analysis
"""
import logging
import os

import click
import jinja2
import yaml
from tabulate import tabulate

from bob.bio.base.score.load import split
from bob.extension.scripts.click_helper import verbosity_option, bool_option
from bob.measure import eer_threshold, farfrr
from bob.measure.script import common_options
from bob.measure.utils import get_fta
from gridtk.generator import expand

logger = logging.getLogger(__name__)


@click.command(context_settings=dict(token_normalize_func=lambda x: x.lower()))
@click.argument('score_jinja_template')
@click.option('-d', '--database', 'databases', multiple=True, required=True,
              show_default=True,
              help='Names of the evaluation databases')
@click.option('-p', '--protocol', 'protocols', multiple=True, required=True,
              show_default=True,
              help='Names of the protocols of the evaluation databases')
@click.option('-a', '--algorithm', 'algorithms', multiple=True, required=True,
              show_default=True,
              help='Names of the algorithms')
@click.option('-td', '--train-database', required=True,
              help='The database that was used to train the algorithms.')
@click.option('-g', '--group', 'groups', multiple=True, show_default=True,
              default=['train', 'dev', 'eval'])
@bool_option('sort', 's', 'Whether the table should be sorted.', True)
@common_options.table_option()
@common_options.output_log_metric_option()
@verbosity_option()
@click.pass_context
def cross(ctx, score_jinja_template, databases, protocols, algorithms,
          train_database, groups, sort, **kwargs):
    """Cross-db analysis metrics
    """
    logger.debug('ctx.meta: %s', ctx.meta)
    logger.debug('ctx.params: %s', ctx.params)

    env = jinja2.Environment(undefined=jinja2.StrictUndefined)

    data = {
        'evaluation': [{'database': db, 'protocol': proto}
                       for db, proto in zip(databases, protocols)],
        'algorithm': algorithms,
        'group': groups,
    }

    metrics = {}

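    # gridtk's `expand` takes the YAML dump of `data` and yields one flat
    # dict per combination of its lists, e.g. (illustrative values):
    #   {'evaluation': {'database': 'replay', 'protocol': 'grandtest'},
    #    'algorithm': 'lbp', 'group': 'dev'}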
    for variables in expand(yaml.dump(data, Dumper=yaml.SafeDumper)):
        logger.debug(variables)

        score_path = env.from_string(score_jinja_template).render(variables)
        logger.debug(score_path)
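        # e.g. with the illustrative template above, this would render to
        # something like 'replay/grandtest/lbp/scores-dev'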

        database = variables['evaluation']['database']
        protocol = variables['evaluation']['protocol']
        algorithm = variables['algorithm']
        group = variables['group']

        # If the algorithm name does not embed the training database name and
        # we are evaluating on a different database, the score files use a
        # '<database>_<algorithm>' prefix; rewrite the path accordingly.
        if train_database not in algorithm and database != train_database:
            score_path = score_path.replace(
                algorithm, database + '_' + algorithm)

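        # a missing score file yields NaN metrics, so the final table still
        # gets a complete row for every combination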
        if not os.path.exists(score_path):
            metrics[(database, protocol, algorithm, group)] = \
                (float('nan'), ) * 5
            continue

        (neg, pos), fta = get_fta(split(score_path))

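        # The eval group is scored at the threshold computed a priori on the
        # matching dev set; dev and train use their own EER threshold.  This
        # relies on 'dev' being expanded before 'eval' (the default group
        # order), so that the dev entry already exists in `metrics`.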
        if group == 'eval':
            threshold = metrics[(database, protocol, algorithm, 'dev')][1]
        else:
            threshold = eer_threshold(neg, pos)

        far, frr = farfrr(neg, pos, threshold)
        hter = (far + frr) / 2

        metrics[(database, protocol, algorithm, group)] = \
            (hter, threshold, fta, far, frr)

    logger.debug('metrics: %s', metrics)

    headers = ["Algorithms"]
    for db in databases:
        headers += [db + "\nEER_t", "\nEER_d", "\nAPCER", "\nBPCER", "\nACER"]
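    # Per-database columns: EER_t/EER_d are the EERs of the train/dev groups;
    # APCER, BPCER and ACER are the eval-group FAR, FRR and HTER at the
    # dev-set EER threshold (all reported in percent).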
    rows = []

    # sort the algorithms by their performance on the training database:
    # eval HTER first, then dev EER, then train EER
    if sort:
        train_protocol = protocols[databases.index(train_database)]

        def sort_key(alg):
            r = []
            for grp in ('eval', 'dev', 'train'):
                hter = metrics.get(
                    (train_database, train_protocol, alg, grp))
                # missing entries sort last
                hter = float('inf') if hter is None else hter[0]
                r.append(hter)
            return tuple(r)

        algorithms = sorted(algorithms, key=sort_key)

    for algorithm in algorithms:
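        # display the algorithm name without any '<train_database>_' prefix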
        rows.append([algorithm.replace(train_database + '_', '')])
        for database, protocol in zip(databases, protocols):
            cell = []
            for group in groups:
                hter, threshold, fta, far, frr = metrics[(
                    database, protocol, algorithm, group)]
                if group == 'eval':
                    cell += [far, frr, hter]
                else:
                    cell += [hter]
            cell = [round(c * 100, 1) for c in cell]
            rows[-1].extend(cell)

    title = ' Trained on {} '.format(train_database)
    title_line = '\n' + '=' * len(title) + '\n'
    click.echo(title_line + title + title_line, file=ctx.meta['log'])
    click.echo(tabulate(rows, headers, ctx.meta['tablefmt'], floatfmt=".1f"),
               file=ctx.meta['log'])