Skip to content
Snippets Groups Projects
Commit 3bc43285 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Fix the bug from last commit

parent eb7aef84
No related branches found
No related tags found
No related merge requests found
......@@ -5,13 +5,14 @@
from __future__ import print_function, absolute_import, division
import os
import numpy as np
from bob.io.base import create_directories_safe
from bob.measure.load import load_score, dump_score
from bob.bio.base import utils
from ..tools import parse_arguments, write_info, get_gza_from_lines_list, \
check_consistency, get_scores, remove_nan
check_consistency, get_scores, remove_nan, get_score_lines, get_2negatives_1positive
import bob.core
logger = bob.core.log.setup("bob.fusion.base")
......@@ -43,23 +44,28 @@ def fuse(args, command_line_parameters):
# check if score lines are consistent
if not args.skip_check:
logger.info('Checking the training files for consistency ...')
check_consistency(gen_lt, zei_lt, atk_lt)
if args.dev_files:
logger.info('Checking the development files for consistency ...')
check_consistency(gen_ld, zei_ld, atk_ld)
if args.eval_files:
logger.info('Checking the evaluation files for consistency ...')
check_consistency(gen_le, zei_le, atk_le)
scores_train = get_scores(score_lines_list_train)
scores_train = get_scores(gen_lt, zei_lt, atk_lt)
train_neg = get_scores(zei_lt, atk_lt)
train_pos = get_scores(gen_lt)
if args.dev_files:
scores_dev = get_scores(score_lines_list_dev)
scores_dev = get_scores(gen_ld, zei_ld, atk_ld)
scores_dev_lines = get_score_lines(gen_ld[0:1], zei_ld[0:1], atk_ld[0:1])
dev_neg = get_scores(zei_ld, atk_ld)
dev_pos = get_scores(gen_ld)
else:
dev_neg, dev_pos = None, None
if args.eval_files:
scores_eval = get_scores(score_lines_list_eval)
scores_eval = get_scores(gen_le, zei_le, atk_le)
scores_eval_lines = get_score_lines(gen_le[0:1], zei_le[0:1], atk_le[0:1])
# check for nan values
found_nan = False
......@@ -102,10 +108,13 @@ def fuse(args, command_line_parameters):
"- Fusion: scores '%s' already exists.", args.fused_dev_file)
elif args.dev_files:
fused_scores_dev = algorithm.fuse(scores_dev)
score_lines = score_lines_list_dev[idx1][~nan_dev]
score_lines = scores_dev_lines[~nan_dev]
score_lines['score'] = fused_scores_dev
gen, zei, atk, _, _, _ = get_2negatives_1positive(score_lines)
create_directories_safe(os.path.dirname(args.fused_dev_file))
dump_score(args.fused_dev_file, score_lines)
dump_score(args.fused_dev_file + '-licit', np.append(gen, zei))
dump_score(args.fused_dev_file + '-spoof', np.append(gen, atk))
# fuse the scores (eval)
if args.eval_files:
......@@ -114,10 +123,13 @@ def fuse(args, command_line_parameters):
"- Fusion: scores '%s' already exists.", args.fused_eval_file)
else:
fused_scores_eval = algorithm.fuse(scores_eval)
score_lines = score_lines_list_eval[idx1][~nan_eval]
score_lines = scores_eval_lines[~nan_eval]
score_lines['score'] = fused_scores_eval
gen, zei, atk, _, _, _ = get_2negatives_1positive(score_lines)
create_directories_safe(os.path.dirname(args.fused_eval_file))
dump_score(args.fused_eval_file, score_lines)
dump_score(args.fused_eval_file + '-licit', np.append(gen, zei))
dump_score(args.fused_eval_file + '-spoof', np.append(gen, atk))
def main(command_line_parameters=None):
......
import numpy as np
from collections import defaultdict
import bob.core
logger = bob.core.log.setup("bob.fusion.base")
......@@ -34,6 +35,34 @@ def get_scores(*args):
return np.vstack(scores).T
def get_score_lines(*args):
# get the dtype names
names = list(args[0][0].dtype.names)
if len(names) != 4:
names = [n for n in names if 'model_label' not in n]
logger.debug(names)
# find the (max) size of strigns
dtypes = [a.dtype for temp in zip(*args) for a in temp]
lengths = defaultdict(list)
for name in names:
for d in dtypes:
lengths[name].append(d[name].itemsize // 4)
# make a new dtype
new_dtype = []
for name in names[:-1]:
new_dtype.append((name, 'U{}'.format(max(lengths[name]))))
new_dtype.append((names[-1], float))
score_lines = []
for temp in zip(*args):
for a in temp:
score_lines.extend(a[names].tolist())
score_lines = np.array(score_lines, dtype=new_dtype)
return score_lines
def remove_nan(samples, found_nan):
ncls = samples.shape[1]
nans = np.isnan(samples[:, 0])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment