From 6ed8639d11f758e03ee50f88d8d0aa85f44ecc57 Mon Sep 17 00:00:00 2001 From: Elie KHOURY <elie.khoury@idiap.ch> Date: Mon, 13 Jul 2015 17:36:27 +0200 Subject: [PATCH] improved script for fusion --- bob/bio/base/script/fusion_llr.py | 45 +++++++++++++++---------------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/bob/bio/base/script/fusion_llr.py b/bob/bio/base/script/fusion_llr.py index 5e51ced0..de522015 100755 --- a/bob/bio/base/script/fusion_llr.py +++ b/bob/bio/base/script/fusion_llr.py @@ -73,19 +73,18 @@ def main(command_line_options = None): for i in range(n_systems): gen_data_dev.append({'4column' : bob.measure.load.four_column, '5column' : bob.measure.load.five_column}[args.parser](args.dev_files[i])) - data_dev = [] - for i in range(n_systems): - data_sys = [] - for (client_id, probe_id, file_id, score) in gen_data_dev[i]: - data_sys.append([client_id, probe_id, file_id, score]) - data_dev.append(data_sys) - - ndata = len(data_dev[0]) outf = open(args.score_fused_dev_file, 'w') - for k in range(ndata): - scores = numpy.array([[v[k][-1] for v in data_dev]], dtype=numpy.float64) + for line in gen_data_dev[0]: + scores= [] + claimed_id = line[0] + real_id = line[-3] + test_label = line[-2] + scores.append(line[-1]) + for n in range(1, n_systems): + scores.append(gen_data_dev[n].next()[-1]) + scores = numpy.array([scores], dtype=numpy.float64) s_fused = machine.forward(scores)[0,0] - line = " ".join(data_dev[0][k][0:-1]) + " " + str(s_fused) + "\n" + line = claimed_id + " " + real_id + " " + test_label + " " + str(s_fused) + "\n" outf.write(line) # fuse evaluation scores @@ -97,21 +96,19 @@ def main(command_line_options = None): for i in range(n_systems): gen_data_eval.append({'4column' : bob.measure.load.four_column, '5column' : bob.measure.load.five_column}[args.parser](args.eval_files[i])) - data_eval = [] - for i in range(n_systems): - data_sys = [] - for (client_id, probe_id, file_id, score) in gen_data_eval[i]: - data_sys.append([client_id, probe_id, file_id, score]) - data_eval.append(data_sys) - - ndata = len(data_eval[0]) outf = open(args.score_fused_eval_file, 'w') - for k in range(ndata): - scores = numpy.array([[v[k][-1] for v in data_eval]], dtype=numpy.float64) - s_fused = machine.forward(scores)[0,0] - line = " ".join(data_eval[0][k][0:-1]) + " " + str(s_fused) + "\n" + for line in gen_data_eval[0]: + scores= [] + claimed_id = line[0] + real_id = line[-3] + test_label = line[-2] + scores.append(line[-1]) + for n in range(1, n_systems): + scores.append(gen_data_eval[n].next()[-1]) + scores = numpy.array([scores], dtype=numpy.float64) + s_fused = machine.forward(scores)[0,0] + line = claimed_id + " " + real_id + " " + test_label + " " + str(s_fused) + "\n" outf.write(line) - return 0 if __name__ == '__main__': -- GitLab