From 37726f4892d049dae6b50e8ec52d8fcdba6e90ff Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Mon, 21 May 2018 12:17:12 +0200
Subject: [PATCH] More detailed and sorted table

---
 bob/pad/base/script/cross.py | 51 ++++++++++++++++++++++++------------
 1 file changed, 34 insertions(+), 17 deletions(-)

diff --git a/bob/pad/base/script/cross.py b/bob/pad/base/script/cross.py
index e21ecc5..aacb939 100644
--- a/bob/pad/base/script/cross.py
+++ b/bob/pad/base/script/cross.py
@@ -9,7 +9,7 @@ from tabulate import tabulate
 from bob.measure import eer_threshold, farfrr
 from bob.measure.script import common_options
 from bob.measure.utils import get_fta
-from bob.extension.scripts.click_helper import verbosity_option
+from bob.extension.scripts.click_helper import verbosity_option, bool_option
 from bob.bio.base.score.load import split
 from gridtk.generator import expand
 
@@ -31,22 +31,17 @@ logger = logging.getLogger(__name__)
               help='The database that was used to train the algorithms.')
 @click.option('-g', '--group', 'groups', multiple=True, show_default=True,
               default=['train', 'dev', 'eval'])
+@bool_option('sort', 's', 'whether the table should be sorted.', True)
 @common_options.table_option()
 @common_options.output_log_metric_option()
 @verbosity_option()
 @click.pass_context
 def cross(ctx, score_jinja_template, databases, protocols, algorithms,
-          train_database, groups, **kwargs):
+          train_database, groups, sort, **kwargs):
     """Cross-db analysis metrics
     """
     logger.debug('ctx.meta: %s', ctx.meta)
-    logger.debug('score_jinja_template: %s', score_jinja_template)
-    logger.debug('databases: %s', databases)
-    logger.debug('protocols: %s', protocols)
-    logger.debug('algorithms: %s', algorithms)
-    logger.debug('train_database: %s', train_database)
-    logger.debug('groups: %s', groups)
-    logger.debug('kwargs: %s', kwargs)
+    logger.debug('ctx.params: %s', ctx.params)
 
     env = jinja2.Environment(undefined=jinja2.StrictUndefined)
 
@@ -95,19 +90,41 @@ def cross(ctx, score_jinja_template, databases, protocols, algorithms,
 
     logger.debug('metrics: %s', metrics)
 
-    headers = ["Algorithms"] + list(databases)
-    raws = []
+    headers = ["Algorithms"]
+    for db in databases:
+        headers += [db + "\nEER_t", "\nEER_d", "\nAPCER", "\nBPCER", "\nACER"]
+    rows = []
+
+    # sort the algorithms based on HTER test, EER dev, EER train
+    if sort:
+        train_protocol = protocols[databases.index(train_database)]
+
+        def sort_key(alg):
+            r = []
+            for grp in ('eval', 'dev', 'train'):
+                hter = metrics.get(
+                    (train_database, train_protocol, alg, group))
+                hter = hter if hter is None else hter[0]
+                r.append(hter)
+            return tuple(r)
+        algorithms = sorted(algorithms, key=sort_key)
 
     for algorithm in algorithms:
-        raws.append([algorithm])
+        rows.append([algorithm.replace(train_database + '_', '')])
         for database, protocol in zip(databases, protocols):
-            cell = ['{:>5.1f}'.format(
-                100 * metrics[(database, protocol, algorithm, group)][0])
-                for group in groups]
-            raws[-1].append(' '.join(cell))
+            cell = []
+            for group in groups:
+                hter, threshold, fta, far, frr = metrics[(
+                    database, protocol, algorithm, group)]
+                if group == 'eval':
+                    cell += [far, frr, hter]
+                else:
+                    cell += [hter]
+            cell = [round(c * 100, 1) for c in cell]
+            rows[-1].extend(cell)
 
     title = ' Trained on {} '.format(train_database)
     title_line = '\n' + '=' * len(title) + '\n'
     click.echo(title_line + title + title_line, file=ctx.meta['log'])
-    click.echo(tabulate(raws, headers, ctx.meta['tablefmt']),
+    click.echo(tabulate(rows, headers, ctx.meta['tablefmt'], floatfmt=".1f"),
                file=ctx.meta['log'])
-- 
GitLab