Skip to content
Snippets Groups Projects
Commit 37726f48 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

More detailed and sorted table

parent ed8f110a
No related branches found
No related tags found
1 merge request!53Cross database testing evaluation
...@@ -9,7 +9,7 @@ from tabulate import tabulate ...@@ -9,7 +9,7 @@ from tabulate import tabulate
from bob.measure import eer_threshold, farfrr from bob.measure import eer_threshold, farfrr
from bob.measure.script import common_options from bob.measure.script import common_options
from bob.measure.utils import get_fta from bob.measure.utils import get_fta
from bob.extension.scripts.click_helper import verbosity_option from bob.extension.scripts.click_helper import verbosity_option, bool_option
from bob.bio.base.score.load import split from bob.bio.base.score.load import split
from gridtk.generator import expand from gridtk.generator import expand
...@@ -31,22 +31,17 @@ logger = logging.getLogger(__name__) ...@@ -31,22 +31,17 @@ logger = logging.getLogger(__name__)
help='The database that was used to train the algorithms.') help='The database that was used to train the algorithms.')
@click.option('-g', '--group', 'groups', multiple=True, show_default=True, @click.option('-g', '--group', 'groups', multiple=True, show_default=True,
default=['train', 'dev', 'eval']) default=['train', 'dev', 'eval'])
@bool_option('sort', 's', 'whether the table should be sorted.', True)
@common_options.table_option() @common_options.table_option()
@common_options.output_log_metric_option() @common_options.output_log_metric_option()
@verbosity_option() @verbosity_option()
@click.pass_context @click.pass_context
def cross(ctx, score_jinja_template, databases, protocols, algorithms, def cross(ctx, score_jinja_template, databases, protocols, algorithms,
train_database, groups, **kwargs): train_database, groups, sort, **kwargs):
"""Cross-db analysis metrics """Cross-db analysis metrics
""" """
logger.debug('ctx.meta: %s', ctx.meta) logger.debug('ctx.meta: %s', ctx.meta)
logger.debug('score_jinja_template: %s', score_jinja_template) logger.debug('ctx.params: %s', ctx.params)
logger.debug('databases: %s', databases)
logger.debug('protocols: %s', protocols)
logger.debug('algorithms: %s', algorithms)
logger.debug('train_database: %s', train_database)
logger.debug('groups: %s', groups)
logger.debug('kwargs: %s', kwargs)
env = jinja2.Environment(undefined=jinja2.StrictUndefined) env = jinja2.Environment(undefined=jinja2.StrictUndefined)
...@@ -95,19 +90,41 @@ def cross(ctx, score_jinja_template, databases, protocols, algorithms, ...@@ -95,19 +90,41 @@ def cross(ctx, score_jinja_template, databases, protocols, algorithms,
logger.debug('metrics: %s', metrics) logger.debug('metrics: %s', metrics)
headers = ["Algorithms"] + list(databases) headers = ["Algorithms"]
raws = [] for db in databases:
headers += [db + "\nEER_t", "\nEER_d", "\nAPCER", "\nBPCER", "\nACER"]
rows = []
# sort the algorithms based on HTER test, EER dev, EER train
if sort:
train_protocol = protocols[databases.index(train_database)]
def sort_key(alg):
r = []
for grp in ('eval', 'dev', 'train'):
hter = metrics.get(
(train_database, train_protocol, alg, group))
hter = hter if hter is None else hter[0]
r.append(hter)
return tuple(r)
algorithms = sorted(algorithms, key=sort_key)
for algorithm in algorithms: for algorithm in algorithms:
raws.append([algorithm]) rows.append([algorithm.replace(train_database + '_', '')])
for database, protocol in zip(databases, protocols): for database, protocol in zip(databases, protocols):
cell = ['{:>5.1f}'.format( cell = []
100 * metrics[(database, protocol, algorithm, group)][0]) for group in groups:
for group in groups] hter, threshold, fta, far, frr = metrics[(
raws[-1].append(' '.join(cell)) database, protocol, algorithm, group)]
if group == 'eval':
cell += [far, frr, hter]
else:
cell += [hter]
cell = [round(c * 100, 1) for c in cell]
rows[-1].extend(cell)
title = ' Trained on {} '.format(train_database) title = ' Trained on {} '.format(train_database)
title_line = '\n' + '=' * len(title) + '\n' title_line = '\n' + '=' * len(title) + '\n'
click.echo(title_line + title + title_line, file=ctx.meta['log']) click.echo(title_line + title + title_line, file=ctx.meta['log'])
click.echo(tabulate(raws, headers, ctx.meta['tablefmt']), click.echo(tabulate(rows, headers, ctx.meta['tablefmt'], floatfmt=".1f"),
file=ctx.meta['log']) file=ctx.meta['log'])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment