Skip to content
Snippets Groups Projects
Commit a029591f authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Add support for renaming algs in bob pad cross

parent cb6d74b9
Branches
Tags
1 merge request!53Cross database testing evaluation
"""Prints Cross-db metrics analysis
"""
import click
import json
import jinja2
import logging
import math
......@@ -20,8 +21,8 @@ logger = logging.getLogger(__name__)
@click.command(epilog='''\b
Examples:
$ bin/bob pad cross 'results/{{ evaluation.database }}/{{ algorithm }}/{{ evaluation.protocol }}/scores/scores-{{ group }}' -td replaymobile -d replaymobile -p grandtest -d oulunpu -p Protocol_1 \
$ bin/bob pad cross 'results/{{ evaluation.database }}/{{ algorithm }}/{{ evaluation.protocol }}/scores/scores-{{ group }}' \
-td replaymobile -d replaymobile -p grandtest -d oulunpu -p Protocol_1 \
-a replaymobile_frame-diff-svm \
-a replaymobile_qm-svm-64 \
-a replaymobile_lbp-svm-64 \
......@@ -37,6 +38,10 @@ Examples:
@click.option('-a', '--algorithm', 'algorithms', multiple=True, required=True,
show_default=True,
help='Names of the algorithms')
@click.option('-n', '--names', type=click.File('r'),
help='Name of algorithms to show in the table. Provide a path '
'to a json file maps algorithm names to names that you want to '
'see in the table.')
@click.option('-td', '--train-database', required=True,
help='The database that was used to train the algorithms.')
@click.option('-g', '--group', 'groups', multiple=True, show_default=True,
......@@ -47,11 +52,13 @@ Examples:
@verbosity_option()
@click.pass_context
def cross(ctx, score_jinja_template, databases, protocols, algorithms,
train_database, groups, sort, **kwargs):
names, train_database, groups, sort, **kwargs):
"""Cross-db analysis metrics
"""
log_parameters(logger)
names = {} if names is None else json.load(names)
env = jinja2.Environment(undefined=jinja2.StrictUndefined)
data = {
......@@ -89,7 +96,11 @@ def cross(ctx, score_jinja_template, databases, protocols, algorithms,
if group == 'eval':
threshold = metrics[(database, protocol, algorithm, 'dev')][1]
else:
try:
threshold = eer_threshold(neg, pos)
except RuntimeError:
logger.error("Something wrong with {}".format(score_path))
raise
far, frr = farfrr(neg, pos, threshold)
hter = (far + frr) / 2
......@@ -117,7 +128,10 @@ def cross(ctx, score_jinja_template, databases, protocols, algorithms,
algorithms = sorted(algorithms, key=sort_key)
for algorithm in algorithms:
rows.append([algorithm.replace(train_database + '_', '').replace(train_protocol + '_', '')])
name = algorithm.replace(train_database + '_', '')
name = name.replace(train_protocol + '_', '')
name = names.get(name, name)
rows.append([name])
for database, protocol in zip(databases, protocols):
cell = []
for group in groups:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment