From 3bc432858ae06dc4e5db7fefe11c13a186c6297a Mon Sep 17 00:00:00 2001 From: Amir Mohammadi <amir.mohammadi@idiap.ch> Date: Fri, 2 Dec 2016 15:27:42 +0100 Subject: [PATCH] Fix the bug from last commit --- bob/fusion/base/script/bob_fuse.py | 24 ++++++++++++++++++------ bob/fusion/base/tools/common.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/bob/fusion/base/script/bob_fuse.py b/bob/fusion/base/script/bob_fuse.py index a456a18..1d314b2 100755 --- a/bob/fusion/base/script/bob_fuse.py +++ b/bob/fusion/base/script/bob_fuse.py @@ -5,13 +5,14 @@ from __future__ import print_function, absolute_import, division import os +import numpy as np from bob.io.base import create_directories_safe from bob.measure.load import load_score, dump_score from bob.bio.base import utils from ..tools import parse_arguments, write_info, get_gza_from_lines_list, \ - check_consistency, get_scores, remove_nan + check_consistency, get_scores, remove_nan, get_score_lines, get_2negatives_1positive import bob.core logger = bob.core.log.setup("bob.fusion.base") @@ -43,23 +44,28 @@ def fuse(args, command_line_parameters): # check if score lines are consistent if not args.skip_check: + logger.info('Checking the training files for consistency ...') check_consistency(gen_lt, zei_lt, atk_lt) if args.dev_files: + logger.info('Checking the development files for consistency ...') check_consistency(gen_ld, zei_ld, atk_ld) if args.eval_files: + logger.info('Checking the evaluation files for consistency ...') check_consistency(gen_le, zei_le, atk_le) - scores_train = get_scores(score_lines_list_train) + scores_train = get_scores(gen_lt, zei_lt, atk_lt) train_neg = get_scores(zei_lt, atk_lt) train_pos = get_scores(gen_lt) if args.dev_files: - scores_dev = get_scores(score_lines_list_dev) + scores_dev = get_scores(gen_ld, zei_ld, atk_ld) + scores_dev_lines = get_score_lines(gen_ld[0:1], zei_ld[0:1], atk_ld[0:1]) dev_neg = get_scores(zei_ld, atk_ld) dev_pos = get_scores(gen_ld) else: dev_neg, dev_pos = None, None if args.eval_files: - scores_eval = get_scores(score_lines_list_eval) + scores_eval = get_scores(gen_le, zei_le, atk_le) + scores_eval_lines = get_score_lines(gen_le[0:1], zei_le[0:1], atk_le[0:1]) # check for nan values found_nan = False @@ -102,10 +108,13 @@ def fuse(args, command_line_parameters): "- Fusion: scores '%s' already exists.", args.fused_dev_file) elif args.dev_files: fused_scores_dev = algorithm.fuse(scores_dev) - score_lines = score_lines_list_dev[idx1][~nan_dev] + score_lines = scores_dev_lines[~nan_dev] score_lines['score'] = fused_scores_dev + gen, zei, atk, _, _, _ = get_2negatives_1positive(score_lines) create_directories_safe(os.path.dirname(args.fused_dev_file)) dump_score(args.fused_dev_file, score_lines) + dump_score(args.fused_dev_file + '-licit', np.append(gen, zei)) + dump_score(args.fused_dev_file + '-spoof', np.append(gen, atk)) # fuse the scores (eval) if args.eval_files: @@ -114,10 +123,13 @@ def fuse(args, command_line_parameters): "- Fusion: scores '%s' already exists.", args.fused_eval_file) else: fused_scores_eval = algorithm.fuse(scores_eval) - score_lines = score_lines_list_eval[idx1][~nan_eval] + score_lines = scores_eval_lines[~nan_eval] score_lines['score'] = fused_scores_eval + gen, zei, atk, _, _, _ = get_2negatives_1positive(score_lines) create_directories_safe(os.path.dirname(args.fused_eval_file)) dump_score(args.fused_eval_file, score_lines) + dump_score(args.fused_eval_file + '-licit', np.append(gen, zei)) + dump_score(args.fused_eval_file + '-spoof', np.append(gen, atk)) def main(command_line_parameters=None): diff --git a/bob/fusion/base/tools/common.py b/bob/fusion/base/tools/common.py index 09235eb..f440fb2 100644 --- a/bob/fusion/base/tools/common.py +++ b/bob/fusion/base/tools/common.py @@ -1,4 +1,5 @@ import numpy as np +from collections import defaultdict import bob.core logger = bob.core.log.setup("bob.fusion.base") @@ -34,6 +35,34 @@ def get_scores(*args): return np.vstack(scores).T +def get_score_lines(*args): + # get the dtype names + names = list(args[0][0].dtype.names) + if len(names) != 4: + names = [n for n in names if 'model_label' not in n] + logger.debug(names) + + # find the (max) size of strigns + dtypes = [a.dtype for temp in zip(*args) for a in temp] + lengths = defaultdict(list) + for name in names: + for d in dtypes: + lengths[name].append(d[name].itemsize // 4) + + # make a new dtype + new_dtype = [] + for name in names[:-1]: + new_dtype.append((name, 'U{}'.format(max(lengths[name])))) + new_dtype.append((names[-1], float)) + + score_lines = [] + for temp in zip(*args): + for a in temp: + score_lines.extend(a[names].tolist()) + score_lines = np.array(score_lines, dtype=new_dtype) + return score_lines + + def remove_nan(samples, found_nan): ncls = samples.shape[1] nans = np.isnan(samples[:, 0]) -- GitLab