From 6ed8639d11f758e03ee50f88d8d0aa85f44ecc57 Mon Sep 17 00:00:00 2001
From: Elie KHOURY <elie.khoury@idiap.ch>
Date: Mon, 13 Jul 2015 17:36:27 +0200
Subject: [PATCH] improved script for fusion

---
 bob/bio/base/script/fusion_llr.py | 45 +++++++++++++++----------------
 1 file changed, 21 insertions(+), 24 deletions(-)

diff --git a/bob/bio/base/script/fusion_llr.py b/bob/bio/base/script/fusion_llr.py
index 5e51ced0..de522015 100755
--- a/bob/bio/base/script/fusion_llr.py
+++ b/bob/bio/base/script/fusion_llr.py
@@ -73,19 +73,18 @@ def main(command_line_options = None):
   for i in range(n_systems):
     gen_data_dev.append({'4column' : bob.measure.load.four_column, '5column' : bob.measure.load.five_column}[args.parser](args.dev_files[i]))
 
-  data_dev = []
-  for i in range(n_systems):
-    data_sys = []
-    for (client_id, probe_id, file_id, score) in gen_data_dev[i]:
-      data_sys.append([client_id, probe_id, file_id, score])
-    data_dev.append(data_sys)
-  
-  ndata = len(data_dev[0])
   outf = open(args.score_fused_dev_file, 'w')
-  for k in range(ndata):
-    scores = numpy.array([[v[k][-1] for v in data_dev]], dtype=numpy.float64)
+  for line in gen_data_dev[0]:
+    scores= []
+    claimed_id = line[0]
+    real_id = line[-3]
+    test_label = line[-2]
+    scores.append(line[-1])
+    for n in range(1, n_systems): 
+      scores.append(gen_data_dev[n].next()[-1])
+    scores = numpy.array([scores], dtype=numpy.float64)
     s_fused = machine.forward(scores)[0,0]  
-    line = " ".join(data_dev[0][k][0:-1]) + " " + str(s_fused) + "\n"
+    line = claimed_id + " " + real_id + " " + test_label + " "  + str(s_fused) + "\n"
     outf.write(line)
 
   # fuse evaluation scores
@@ -97,21 +96,19 @@ def main(command_line_options = None):
     for i in range(n_systems):
       gen_data_eval.append({'4column' : bob.measure.load.four_column, '5column' : bob.measure.load.five_column}[args.parser](args.eval_files[i]))
       
-    data_eval = []
-    for i in range(n_systems):
-      data_sys = []
-      for (client_id, probe_id, file_id, score) in gen_data_eval[i]:
-        data_sys.append([client_id, probe_id, file_id, score])
-      data_eval.append(data_sys)
-      
-    ndata = len(data_eval[0])
     outf = open(args.score_fused_eval_file, 'w')
-    for k in range(ndata):
-      scores = numpy.array([[v[k][-1] for v in data_eval]], dtype=numpy.float64)
-      s_fused = machine.forward(scores)[0,0]
-      line = " ".join(data_eval[0][k][0:-1]) + " " + str(s_fused) + "\n"
+    for line in gen_data_eval[0]:
+      scores= []
+      claimed_id = line[0]
+      real_id = line[-3]
+      test_label = line[-2]
+      scores.append(line[-1])
+      for n in range(1, n_systems): 
+        scores.append(gen_data_eval[n].next()[-1])
+      scores = numpy.array([scores], dtype=numpy.float64)
+      s_fused = machine.forward(scores)[0,0]  
+      line = claimed_id + " " + real_id + " " + test_label + " "  + str(s_fused) + "\n"
       outf.write(line)
-
   return 0
 
 if __name__ == '__main__':
-- 
GitLab