Skip to content
Snippets Groups Projects
Commit 3bc43285 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Fix the bug from last commit

parent eb7aef84
No related branches found
No related tags found
No related merge requests found
...@@ -5,13 +5,14 @@ ...@@ -5,13 +5,14 @@
from __future__ import print_function, absolute_import, division from __future__ import print_function, absolute_import, division
import os import os
import numpy as np
from bob.io.base import create_directories_safe from bob.io.base import create_directories_safe
from bob.measure.load import load_score, dump_score from bob.measure.load import load_score, dump_score
from bob.bio.base import utils from bob.bio.base import utils
from ..tools import parse_arguments, write_info, get_gza_from_lines_list, \ from ..tools import parse_arguments, write_info, get_gza_from_lines_list, \
check_consistency, get_scores, remove_nan check_consistency, get_scores, remove_nan, get_score_lines, get_2negatives_1positive
import bob.core import bob.core
logger = bob.core.log.setup("bob.fusion.base") logger = bob.core.log.setup("bob.fusion.base")
...@@ -43,23 +44,28 @@ def fuse(args, command_line_parameters): ...@@ -43,23 +44,28 @@ def fuse(args, command_line_parameters):
# check if score lines are consistent # check if score lines are consistent
if not args.skip_check: if not args.skip_check:
logger.info('Checking the training files for consistency ...')
check_consistency(gen_lt, zei_lt, atk_lt) check_consistency(gen_lt, zei_lt, atk_lt)
if args.dev_files: if args.dev_files:
logger.info('Checking the development files for consistency ...')
check_consistency(gen_ld, zei_ld, atk_ld) check_consistency(gen_ld, zei_ld, atk_ld)
if args.eval_files: if args.eval_files:
logger.info('Checking the evaluation files for consistency ...')
check_consistency(gen_le, zei_le, atk_le) check_consistency(gen_le, zei_le, atk_le)
scores_train = get_scores(score_lines_list_train) scores_train = get_scores(gen_lt, zei_lt, atk_lt)
train_neg = get_scores(zei_lt, atk_lt) train_neg = get_scores(zei_lt, atk_lt)
train_pos = get_scores(gen_lt) train_pos = get_scores(gen_lt)
if args.dev_files: if args.dev_files:
scores_dev = get_scores(score_lines_list_dev) scores_dev = get_scores(gen_ld, zei_ld, atk_ld)
scores_dev_lines = get_score_lines(gen_ld[0:1], zei_ld[0:1], atk_ld[0:1])
dev_neg = get_scores(zei_ld, atk_ld) dev_neg = get_scores(zei_ld, atk_ld)
dev_pos = get_scores(gen_ld) dev_pos = get_scores(gen_ld)
else: else:
dev_neg, dev_pos = None, None dev_neg, dev_pos = None, None
if args.eval_files: if args.eval_files:
scores_eval = get_scores(score_lines_list_eval) scores_eval = get_scores(gen_le, zei_le, atk_le)
scores_eval_lines = get_score_lines(gen_le[0:1], zei_le[0:1], atk_le[0:1])
# check for nan values # check for nan values
found_nan = False found_nan = False
...@@ -102,10 +108,13 @@ def fuse(args, command_line_parameters): ...@@ -102,10 +108,13 @@ def fuse(args, command_line_parameters):
"- Fusion: scores '%s' already exists.", args.fused_dev_file) "- Fusion: scores '%s' already exists.", args.fused_dev_file)
elif args.dev_files: elif args.dev_files:
fused_scores_dev = algorithm.fuse(scores_dev) fused_scores_dev = algorithm.fuse(scores_dev)
score_lines = score_lines_list_dev[idx1][~nan_dev] score_lines = scores_dev_lines[~nan_dev]
score_lines['score'] = fused_scores_dev score_lines['score'] = fused_scores_dev
gen, zei, atk, _, _, _ = get_2negatives_1positive(score_lines)
create_directories_safe(os.path.dirname(args.fused_dev_file)) create_directories_safe(os.path.dirname(args.fused_dev_file))
dump_score(args.fused_dev_file, score_lines) dump_score(args.fused_dev_file, score_lines)
dump_score(args.fused_dev_file + '-licit', np.append(gen, zei))
dump_score(args.fused_dev_file + '-spoof', np.append(gen, atk))
# fuse the scores (eval) # fuse the scores (eval)
if args.eval_files: if args.eval_files:
...@@ -114,10 +123,13 @@ def fuse(args, command_line_parameters): ...@@ -114,10 +123,13 @@ def fuse(args, command_line_parameters):
"- Fusion: scores '%s' already exists.", args.fused_eval_file) "- Fusion: scores '%s' already exists.", args.fused_eval_file)
else: else:
fused_scores_eval = algorithm.fuse(scores_eval) fused_scores_eval = algorithm.fuse(scores_eval)
score_lines = score_lines_list_eval[idx1][~nan_eval] score_lines = scores_eval_lines[~nan_eval]
score_lines['score'] = fused_scores_eval score_lines['score'] = fused_scores_eval
gen, zei, atk, _, _, _ = get_2negatives_1positive(score_lines)
create_directories_safe(os.path.dirname(args.fused_eval_file)) create_directories_safe(os.path.dirname(args.fused_eval_file))
dump_score(args.fused_eval_file, score_lines) dump_score(args.fused_eval_file, score_lines)
dump_score(args.fused_eval_file + '-licit', np.append(gen, zei))
dump_score(args.fused_eval_file + '-spoof', np.append(gen, atk))
def main(command_line_parameters=None): def main(command_line_parameters=None):
......
import numpy as np import numpy as np
from collections import defaultdict
import bob.core import bob.core
logger = bob.core.log.setup("bob.fusion.base") logger = bob.core.log.setup("bob.fusion.base")
...@@ -34,6 +35,34 @@ def get_scores(*args): ...@@ -34,6 +35,34 @@ def get_scores(*args):
return np.vstack(scores).T return np.vstack(scores).T
def get_score_lines(*args):
# get the dtype names
names = list(args[0][0].dtype.names)
if len(names) != 4:
names = [n for n in names if 'model_label' not in n]
logger.debug(names)
# find the (max) size of strigns
dtypes = [a.dtype for temp in zip(*args) for a in temp]
lengths = defaultdict(list)
for name in names:
for d in dtypes:
lengths[name].append(d[name].itemsize // 4)
# make a new dtype
new_dtype = []
for name in names[:-1]:
new_dtype.append((name, 'U{}'.format(max(lengths[name]))))
new_dtype.append((names[-1], float))
score_lines = []
for temp in zip(*args):
for a in temp:
score_lines.extend(a[names].tolist())
score_lines = np.array(score_lines, dtype=new_dtype)
return score_lines
def remove_nan(samples, found_nan): def remove_nan(samples, found_nan):
ncls = samples.shape[1] ncls = samples.shape[1]
nans = np.isnan(samples[:, 0]) nans = np.isnan(samples[:, 0])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment