Skip to content
Snippets Groups Projects
Commit 83a87b37 authored by Amir MOHAMMADI
Browse files

improve behaviour and tests

parent 71ad129a
No related branches found
No related tags found
1 merge request: !8 improve behaviour and tests
Pipeline #
......@@ -60,15 +60,15 @@ def routine_fusion(
scores_dev_lines=None, scores_dev=None, dev_neg=None, dev_pos=None,
fused_dev_file=None,
scores_eval_lines=None, scores_eval=None, fused_eval_file=None,
force=False, min_file_size=1000):
force=False, min_file_size=1000, do_training=True):
# load the model if model_file exists and training was not requested (do_training is False)
if scores_train is None and os.path.exists(model_file):
if os.path.exists(model_file) and not do_training:
logger.info("Loading the algorithm from %s", model_file)
algorithm = algorithm.load(model_file)
# train the preprocessors
if train_neg is not None:
if train_neg is not None and do_training:
train_scores = np.vstack((train_neg, train_pos))
neg_len = train_neg.shape[0]
y = np.zeros((train_scores.shape[0],), dtype='bool')
......@@ -90,7 +90,7 @@ def routine_fusion(
scores_eval = algorithm.preprocess(scores_eval)
# Train the classifier
if train_neg is not None:
if train_neg is not None and do_training:
if utils.check_file(model_file, force, min_file_size):
logger.info(
"model '%s' already exists.", model_file)
......@@ -182,7 +182,8 @@ def fuse(scores, algorithm, groups, output_dir, model_file, skip_check, force,
The list of score files. The scores must correspond to the groups
parameter and scores of each system will come after the last one.
algorithm : :any:`bob.fusion.algorithm.Algorithm`
The fusion algorithm.
The fusion algorithm. It can be provided using `bob.fusion.algorithm`
setuptools entry-points or config files.
groups : [str]
The groups of the scores. This should correspond to the scores that are
provided. The order of options are important and should be in the same
......@@ -196,6 +197,7 @@ def fuse(scores, algorithm, groups, output_dir, model_file, skip_check, force,
force : bool, optional
Whether to overwrite existing files.
\b
Raises
------
click.BadArgumentUsage
......@@ -205,7 +207,10 @@ def fuse(scores, algorithm, groups, output_dir, model_file, skip_check, force,
"""
create_directories_safe(output_dir)
if not model_file:
do_training = True
model_file = os.path.join(output_dir, 'Model.pkl')
else:
do_training = False
fused_train_file = os.path.join(output_dir, 'scores-train')
fused_dev_file = os.path.join(output_dir, 'scores-dev')
fused_eval_file = os.path.join(output_dir, 'scores-eval')
......@@ -319,4 +324,4 @@ def fuse(scores, algorithm, groups, output_dir, model_file, skip_check, force,
algorithm, model_file, scores_train_lines, scores_train,
train_neg, train_pos, fused_train_file, scores_dev_lines,
scores_dev, dev_neg, dev_pos, fused_dev_file, scores_eval_lines,
scores_eval, fused_eval_file, force)
scores_eval, fused_eval_file, force, do_training=do_training)
......@@ -35,22 +35,39 @@ def compare_scores(path1, path2):
assert all(score1[name] == score2[name])
def click_result(result):
    """Summarise a click test result for use as an assertion message.

    Parameters
    ----------
    result : object
        Any object exposing ``exit_code``, ``output`` and ``exception``
        attributes (the tests pass what ``runner.invoke`` returns).

    Returns
    -------
    str
        ``"<exit_code>, <output>, <exception>"``.
    """
    # The exception is rendered with %r rather than %s: str() of many
    # exceptions drops the type (str(KeyError('x')) is just "'x'"), which is
    # the most useful piece of information when a test fails.
    return "%s, %s, %r" % (result.exit_code, result.output, result.exception)
def test_fuse():
# End-to-end check of the `fuse` command: invoke it through the click test
# runner inside an isolated filesystem and compare the score files it writes
# with the expected files listed in fused_train_files / fused_eval_files.
runner = CliRunner()
with runner.isolated_filesystem():
# Paths (relative to the isolated directory) read by the comparisons below.
fused_train_file = os.path.join('fusion_result', 'scores-train')
fused_eval_file = os.path.join('fusion_result', 'scores-eval')
# Test with training
# The comprehension interleaves the inputs pairwise (train[0], eval[0],
# train[1], eval[1], ...) before the group and algorithm options are appended.
cmd = [x for xy in zip(train_files, eval_files) for x in xy] + \
['-g', 'train', '-g', 'eval', '-a', 'llr']
# NOTE(review): this text is a rendered diff without +/- markers. The `for`
# block below and the un-looped run that follows it look like the pre- and
# post-change versions of the same step -- confirm against the repository
# before treating both as live code.
for _ in range(2):
result = runner.invoke(fuse, cmd)
assert result.exit_code == 0
compare_scores(fused_train_file, fused_train_files[0])
compare_scores(fused_train_file + '-licit', fused_train_files[1])
compare_scores(fused_train_file + '-spoof', fused_train_files[2])
compare_scores(fused_eval_file, fused_eval_files[0])
compare_scores(fused_eval_file + '-licit', fused_eval_files[1])
compare_scores(fused_eval_file + '-spoof', fused_eval_files[2])
result = runner.invoke(fuse, cmd)
# On failure the assertion message carries exit code, output and exception.
assert result.exit_code == 0, click_result(result)
compare_scores(fused_train_file, fused_train_files[0])
compare_scores(fused_train_file + '-licit', fused_train_files[1])
compare_scores(fused_train_file + '-spoof', fused_train_files[2])
compare_scores(fused_eval_file, fused_eval_files[0])
compare_scores(fused_eval_file + '-licit', fused_eval_files[1])
compare_scores(fused_eval_file + '-spoof', fused_eval_files[2])
# Test without training
# Only the eval files and the 'eval' group are passed, together with '-m'
# pointing at a model file -- presumably the one produced by the run above
# (TODO confirm the default output directory is 'fusion_result').
cmd = eval_files + ['-g', 'eval', '-a',
'llr', '-m', 'fusion_result/Model.pkl']
result = runner.invoke(fuse, cmd)
assert result.exit_code == 0, click_result(result)
# No train group is given in this invocation, so the train files compared
# here are presumably still those written by the earlier run -- verify.
compare_scores(fused_train_file, fused_train_files[0])
compare_scores(fused_train_file + '-licit', fused_train_files[1])
compare_scores(fused_train_file + '-spoof', fused_train_files[2])
compare_scores(fused_eval_file, fused_eval_files[0])
compare_scores(fused_eval_file + '-licit', fused_eval_files[1])
compare_scores(fused_eval_file + '-spoof', fused_eval_files[2])
def test_fuse_train_only():
......@@ -60,7 +77,7 @@ def test_fuse_train_only():
cmd = train_files + \
['-g', 'train', '-a', 'llr']
result = runner.invoke(fuse, cmd)
assert result.exit_code == 0
assert result.exit_code == 0, click_result(result)
compare_scores(fused_train_file, fused_train_files[0])
compare_scores(fused_train_file + '-licit', fused_train_files[1])
compare_scores(fused_train_file + '-spoof', fused_train_files[2])
......@@ -72,7 +89,7 @@ def test_fuse_with_dev():
cmd = train_files + train_files + \
['-g', 'train', '-g', 'dev', '-a', 'llr']
result = runner.invoke(fuse, cmd)
assert result.exit_code == 0
assert result.exit_code == 0, click_result(result)
def test_fuse_inconsistent():
......@@ -105,7 +122,7 @@ def test_fuse_inconsistent():
cmd = train_files[0:1] + [wrong_train2] + \
['-g', 'train', '-a', 'llr', '--skip-check']
result = runner.invoke(fuse, cmd)
assert result.exit_code == 0, result.exit_code
assert result.exit_code == 0, click_result(result)
assert not result.exception, result.exception
......@@ -115,12 +132,12 @@ def test_boundary():
cmd = train_files + \
['-g', 'train', '-a', 'llr']
result = runner.invoke(fuse, cmd)
assert result.exit_code == 0
assert result.exit_code == 0, click_result(result)
model_file = 'fusion_result/Model.pkl'
cmd = eval_files + ['-m', model_file, '-t', '0']
result = runner.invoke(boundary, cmd)
assert result.exit_code == 0
assert result.exit_code == 0, click_result(result)
def test_boundary_grouping():
......@@ -129,15 +146,15 @@ def test_boundary_grouping():
cmd = train_files + \
['-g', 'train', '-a', 'llr']
result = runner.invoke(fuse, cmd)
assert result.exit_code == 0
assert result.exit_code == 0, click_result(result)
model_file = 'fusion_result/Model.pkl'
cmd1 = eval_files + ['-m', model_file, '-t', '0']
cmd = cmd1 + ['-G', 'random', '-g', '50']
result = runner.invoke(boundary, cmd)
assert result.exit_code == 0
assert result.exit_code == 0, click_result(result)
cmd = cmd1 + ['-G', 'kmeans', '-g', '50']
result = runner.invoke(boundary, cmd)
assert result.exit_code == 0
assert result.exit_code == 0, click_result(result)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment