From 3bc432858ae06dc4e5db7fefe11c13a186c6297a Mon Sep 17 00:00:00 2001
From: Amir Mohammadi <amir.mohammadi@idiap.ch>
Date: Fri, 2 Dec 2016 15:27:42 +0100
Subject: [PATCH] Fix the bug from last commit

---
 bob/fusion/base/script/bob_fuse.py | 24 ++++++++++++++++++------
 bob/fusion/base/tools/common.py    | 29 +++++++++++++++++++++++++++++
 2 files changed, 47 insertions(+), 6 deletions(-)

diff --git a/bob/fusion/base/script/bob_fuse.py b/bob/fusion/base/script/bob_fuse.py
index a456a18..1d314b2 100755
--- a/bob/fusion/base/script/bob_fuse.py
+++ b/bob/fusion/base/script/bob_fuse.py
@@ -5,13 +5,14 @@
 from __future__ import print_function, absolute_import, division
 
 import os
+import numpy as np
 
 from bob.io.base import create_directories_safe
 from bob.measure.load import load_score, dump_score
 from bob.bio.base import utils
 
 from ..tools import parse_arguments, write_info, get_gza_from_lines_list, \
-   check_consistency, get_scores, remove_nan
+   check_consistency, get_scores, remove_nan, get_score_lines, get_2negatives_1positive
 
 import bob.core
 logger = bob.core.log.setup("bob.fusion.base")
@@ -43,23 +44,28 @@ def fuse(args, command_line_parameters):
 
   # check if score lines are consistent
   if not args.skip_check:
+    logger.info('Checking the training files for consistency ...')
     check_consistency(gen_lt, zei_lt, atk_lt)
     if args.dev_files:
+      logger.info('Checking the development files for consistency ...')
       check_consistency(gen_ld, zei_ld, atk_ld)
     if args.eval_files:
+      logger.info('Checking the evaluation files for consistency ...')
       check_consistency(gen_le, zei_le, atk_le)
 
-  scores_train = get_scores(score_lines_list_train)
+  scores_train = get_scores(gen_lt, zei_lt, atk_lt)
   train_neg = get_scores(zei_lt, atk_lt)
   train_pos = get_scores(gen_lt)
   if args.dev_files:
-    scores_dev = get_scores(score_lines_list_dev)
+    scores_dev = get_scores(gen_ld, zei_ld, atk_ld)
+    scores_dev_lines = get_score_lines(gen_ld[0:1], zei_ld[0:1], atk_ld[0:1])
     dev_neg = get_scores(zei_ld, atk_ld)
     dev_pos = get_scores(gen_ld)
   else:
     dev_neg, dev_pos = None, None
   if args.eval_files:
-    scores_eval = get_scores(score_lines_list_eval)
+    scores_eval = get_scores(gen_le, zei_le, atk_le)
+    scores_eval_lines = get_score_lines(gen_le[0:1], zei_le[0:1], atk_le[0:1])
 
   # check for nan values
   found_nan = False
@@ -102,10 +108,13 @@ def fuse(args, command_line_parameters):
       "- Fusion: scores '%s' already exists.", args.fused_dev_file)
   elif args.dev_files:
     fused_scores_dev = algorithm.fuse(scores_dev)
-    score_lines = score_lines_list_dev[idx1][~nan_dev]
+    score_lines = scores_dev_lines[~nan_dev]
     score_lines['score'] = fused_scores_dev
+    gen, zei, atk, _, _, _ = get_2negatives_1positive(score_lines)
     create_directories_safe(os.path.dirname(args.fused_dev_file))
     dump_score(args.fused_dev_file, score_lines)
+    dump_score(args.fused_dev_file + '-licit', np.append(gen, zei))
+    dump_score(args.fused_dev_file + '-spoof', np.append(gen, atk))
 
   # fuse the scores (eval)
   if args.eval_files:
@@ -114,10 +123,13 @@ def fuse(args, command_line_parameters):
         "- Fusion: scores '%s' already exists.", args.fused_eval_file)
     else:
       fused_scores_eval = algorithm.fuse(scores_eval)
-      score_lines = score_lines_list_eval[idx1][~nan_eval]
+      score_lines = scores_eval_lines[~nan_eval]
       score_lines['score'] = fused_scores_eval
+      gen, zei, atk, _, _, _ = get_2negatives_1positive(score_lines)
       create_directories_safe(os.path.dirname(args.fused_eval_file))
       dump_score(args.fused_eval_file, score_lines)
+      dump_score(args.fused_eval_file + '-licit', np.append(gen, zei))
+      dump_score(args.fused_eval_file + '-spoof', np.append(gen, atk))
 
 
 def main(command_line_parameters=None):
diff --git a/bob/fusion/base/tools/common.py b/bob/fusion/base/tools/common.py
index 09235eb..f440fb2 100644
--- a/bob/fusion/base/tools/common.py
+++ b/bob/fusion/base/tools/common.py
@@ -1,4 +1,5 @@
 import numpy as np
+from collections import defaultdict
 
 import bob.core
 logger = bob.core.log.setup("bob.fusion.base")
@@ -34,6 +35,34 @@ def get_scores(*args):
   return np.vstack(scores).T
 
 
+def get_score_lines(*args):
+  # get the dtype names
+  names = list(args[0][0].dtype.names)
+  if len(names) != 4:
+    names = [n for n in names if 'model_label' not in n]
+  logger.debug(names)
+
+  # find the (max) size of strigns
+  dtypes = [a.dtype for temp in zip(*args) for a in temp]
+  lengths = defaultdict(list)
+  for name in names:
+    for d in dtypes:
+      lengths[name].append(d[name].itemsize // 4)
+
+  # make a new dtype
+  new_dtype = []
+  for name in names[:-1]:
+    new_dtype.append((name, 'U{}'.format(max(lengths[name]))))
+  new_dtype.append((names[-1], float))
+
+  score_lines = []
+  for temp in zip(*args):
+    for a in temp:
+      score_lines.extend(a[names].tolist())
+  score_lines = np.array(score_lines, dtype=new_dtype)
+  return score_lines
+
+
 def remove_nan(samples, found_nan):
   ncls = samples.shape[1]
   nans = np.isnan(samples[:, 0])
-- 
GitLab