From 37726f4892d049dae6b50e8ec52d8fcdba6e90ff Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Mon, 21 May 2018 12:17:12 +0200 Subject: [PATCH] More detailed and sorted table --- bob/pad/base/script/cross.py | 51 ++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/bob/pad/base/script/cross.py b/bob/pad/base/script/cross.py index e21ecc5..aacb939 100644 --- a/bob/pad/base/script/cross.py +++ b/bob/pad/base/script/cross.py @@ -9,7 +9,7 @@ from tabulate import tabulate from bob.measure import eer_threshold, farfrr from bob.measure.script import common_options from bob.measure.utils import get_fta -from bob.extension.scripts.click_helper import verbosity_option +from bob.extension.scripts.click_helper import verbosity_option, bool_option from bob.bio.base.score.load import split from gridtk.generator import expand @@ -31,22 +31,17 @@ logger = logging.getLogger(__name__) help='The database that was used to train the algorithms.') @click.option('-g', '--group', 'groups', multiple=True, show_default=True, default=['train', 'dev', 'eval']) +@bool_option('sort', 's', 'whether the table should be sorted.', True) @common_options.table_option() @common_options.output_log_metric_option() @verbosity_option() @click.pass_context def cross(ctx, score_jinja_template, databases, protocols, algorithms, - train_database, groups, **kwargs): + train_database, groups, sort, **kwargs): """Cross-db analysis metrics """ logger.debug('ctx.meta: %s', ctx.meta) - logger.debug('score_jinja_template: %s', score_jinja_template) - logger.debug('databases: %s', databases) - logger.debug('protocols: %s', protocols) - logger.debug('algorithms: %s', algorithms) - logger.debug('train_database: %s', train_database) - logger.debug('groups: %s', groups) - logger.debug('kwargs: %s', kwargs) + logger.debug('ctx.params: %s', ctx.params) env = jinja2.Environment(undefined=jinja2.StrictUndefined) @@ -95,19 +90,41 @@ def cross(ctx, score_jinja_template, databases, protocols, algorithms, logger.debug('metrics: %s', metrics) - headers = ["Algorithms"] + list(databases) - raws = [] + headers = ["Algorithms"] + for db in databases: + headers += [db + "\nEER_t", "\nEER_d", "\nAPCER", "\nBPCER", "\nACER"] + rows = [] + + # sort the algorithms based on HTER test, EER dev, EER train + if sort: + train_protocol = protocols[databases.index(train_database)] + + def sort_key(alg): + r = [] + for grp in ('eval', 'dev', 'train'): + hter = metrics.get( + (train_database, train_protocol, alg, group)) + hter = hter if hter is None else hter[0] + r.append(hter) + return tuple(r) + algorithms = sorted(algorithms, key=sort_key) for algorithm in algorithms: - raws.append([algorithm]) + rows.append([algorithm.replace(train_database + '_', '')]) for database, protocol in zip(databases, protocols): - cell = ['{:>5.1f}'.format( - 100 * metrics[(database, protocol, algorithm, group)][0]) - for group in groups] - raws[-1].append(' '.join(cell)) + cell = [] + for group in groups: + hter, threshold, fta, far, frr = metrics[( + database, protocol, algorithm, group)] + if group == 'eval': + cell += [far, frr, hter] + else: + cell += [hter] + cell = [round(c * 100, 1) for c in cell] + rows[-1].extend(cell) title = ' Trained on {} '.format(train_database) title_line = '\n' + '=' * len(title) + '\n' click.echo(title_line + title + title_line, file=ctx.meta['log']) - click.echo(tabulate(raws, headers, ctx.meta['tablefmt']), + click.echo(tabulate(rows, headers, ctx.meta['tablefmt'], floatfmt=".1f"), file=ctx.meta['log']) -- GitLab