"""Prints Cross-db metrics analysis
"""
import click
import jinja2
import logging
import math
import os
import yaml
from bob.bio.base.score.load import split
from bob.extension.scripts.click_helper import verbosity_option, bool_option
from bob.measure import eer_threshold, farfrr
from bob.measure.script import common_options
from bob.measure.utils import get_fta
from gridtk.generator import expand
from tabulate import tabulate

logger = logging.getLogger(__name__)


@click.command(context_settings=dict(token_normalize_func=lambda x: x.lower()))
@click.argument('score_jinja_template')
@click.option('-d', '--database', 'databases', multiple=True, required=True,
              show_default=True,
              help='Names of the evaluation databases')
@click.option('-p', '--protocol', 'protocols', multiple=True, required=True,
              show_default=True,
              help='Names of the protocols of the evaluation databases')
@click.option('-a', '--algorithm', 'algorithms', multiple=True, required=True,
              show_default=True,
              help='Names of the algorithms')
@click.option('-td', '--train-database', required=True,
              help='The database that was used to train the algorithms.')
@click.option('-g', '--group', 'groups', multiple=True, show_default=True,
              default=['train', 'dev', 'eval'])
@bool_option('sort', 's', 'Whether the table should be sorted.', True)
@common_options.table_option()
@common_options.output_log_metric_option()
@verbosity_option()
@click.pass_context
def cross(ctx, score_jinja_template, databases, protocols, algorithms,
          train_database, groups, sort, **kwargs):
    """Cross-db analysis metrics
    """
    logger.debug('ctx.meta: %s', ctx.meta)
    logger.debug('ctx.params: %s', ctx.params)

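    # StrictUndefined makes rendering fail loudly if the template references a
    # variable that was not provided, instead of silently producing bad paths.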
    env = jinja2.Environment(undefined=jinja2.StrictUndefined)

    data = {
        'evaluation': [{'database': db, 'protocol': proto}
                       for db, proto in zip(databases, protocols)],
        'algorithm': algorithms,
        'group': groups,
    }

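    # metrics maps (database, protocol, algorithm, group) to the tuple
    # (hter, threshold, fta, far, frr) computed below.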
    metrics = {}

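    # gridtk's expand yields one dict of variables per combination of the
    # lists above (evaluation pair x algorithm x group).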
    for variables in expand(yaml.dump(data, Dumper=yaml.SafeDumper)):
        logger.debug(variables)

        score_path = env.from_string(score_jinja_template).render(variables)
        logger.debug(score_path)

        database, protocol, algorithm, group = \
            variables['evaluation']['database'], \
            variables['evaluation']['protocol'], \
            variables['algorithm'], variables['group']

        # If the algorithm name does not embed the training database name and
        # we are scoring a different database, the score files are expected to
        # be prefixed with the evaluation database name.
        if train_database not in algorithm and database != train_database:
            score_path = score_path.replace(
                algorithm, database + '_' + algorithm)

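        # A missing score file yields a row of NaNs instead of an error.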
        if not os.path.exists(score_path):
            metrics[(database, protocol, algorithm, group)] = \
                (float('nan'), ) * 5
            continue

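        # split() loads the negative and positive scores; get_fta drops NaN
        # scores and reports the failure-to-acquire (FTA) rate.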
        (neg, pos), fta = get_fta(split(score_path))

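        # train/dev use their own EER threshold; eval reuses the dev threshold
        # (this relies on 'dev' being expanded before 'eval' in groups).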
        if group == 'eval':
            threshold = metrics[(database, protocol, algorithm, 'dev')][1]
        else:
            threshold = eer_threshold(neg, pos)

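        # FAR/FRR at the chosen threshold; HTER is their average.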
        far, frr = farfrr(neg, pos, threshold)
        hter = (far + frr) / 2

        metrics[(database, protocol, algorithm, group)] = \
            (hter, threshold, fta, far, frr)

    logger.debug('metrics: %s', metrics)

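    # Per database: EER on train, EER on dev, then APCER/BPCER/ACER on eval
    # (FAR/FRR/HTER at the dev threshold), all reported in percent.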
    headers = ["Algorithms"]
    for db in databases:
        headers += [db + "\nEER_t", "\nEER_d", "\nAPCER", "\nBPCER", "\nACER"]
    rows = []

    # Sort algorithms by their error rates on the training database:
    # eval HTER first, then dev EER, then train EER (NaN sorts last).
    if sort:
        train_protocol = protocols[databases.index(train_database)]

        def sort_key(alg):
            r = []
            for grp in ('eval', 'dev', 'train'):
                hter = metrics[(train_database, train_protocol, alg, grp)][0]
                r.append(1 if math.isnan(hter) else hter)
            return tuple(r)
        algorithms = sorted(algorithms, key=sort_key)

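    # One row per algorithm; the training-database prefix is stripped from the
    # displayed name.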
    for algorithm in algorithms:
        rows.append([algorithm.replace(train_database + '_', '')])
        for database, protocol in zip(databases, protocols):
            cell = []
            for group in groups:
                hter, threshold, fta, far, frr = metrics[(
                    database, protocol, algorithm, group)]
                if group == 'eval':
                    cell += [far, frr, hter]
                else:
                    cell += [hter]
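            # express all error rates as percentages with one decimal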
            cell = [round(c * 100, 1) for c in cell]
            rows[-1].extend(cell)

    title = ' Trained on {} '.format(train_database)
    title_line = '\n' + '=' * len(title) + '\n'
    click.echo(title_line + title + title_line, file=ctx.meta['log'])
    click.echo(tabulate(rows, headers, ctx.meta['tablefmt'], floatfmt=".1f"),
               file=ctx.meta['log'])