Commit a6b5be1d authored by Manuel Günther's avatar Manuel Günther
Browse files

Small restructuring and less repetitions

parent ec51cfc3
Pipeline #11579 passed with stages
in 18 minutes and 37 seconds
......@@ -206,7 +206,7 @@ def add_jobs(args, submitter):
dependencies = score_deps[group])
concat_deps[group].extend([job_ids['score-%s-B'%group], job_ids['score-%s-C'%group], job_ids['score-%s-D'%group], job_ids['score-%s-Z'%group]])
else:
concat_deps[group] = []
concat_deps[group] = deps[:]
# concatenate results
if not args.skip_concatenation:
......@@ -309,48 +309,26 @@ def execute(args):
# enroll the models
elif args.sub_task == 'enroll':
if args.model_type == 'N':
model_ids = fs.model_ids(args.group) if args.model_type == 'N' else fs.t_model_ids(args.group)
tools.enroll(
args.algorithm,
args.extractor,
args.zt_norm,
indices = tools.indices(fs.model_ids(args.group), None if args.grid is None else args.grid.number_of_enrollment_jobs),
indices = tools.indices(model_ids, None if args.grid is None else args.grid.number_of_enrollment_jobs),
groups = [args.group],
types = ['N'],
allow_missing_files = args.allow_missing_files,
force = args.force)
else:
tools.enroll(
args.algorithm,
args.extractor,
args.zt_norm,
indices = tools.indices(fs.t_model_ids(args.group), None if args.grid is None else args.grid.number_of_enrollment_jobs),
groups = [args.group],
types = ['T'],
types = [args.model_type],
allow_missing_files = args.allow_missing_files,
force = args.force)
# compute scores
elif args.sub_task == 'compute-scores':
if args.score_type in ['A', 'B']:
tools.compute_scores(
args.algorithm,
args.extractor,
args.zt_norm,
indices = tools.indices(fs.model_ids(args.group), None if args.grid is None else args.grid.number_of_scoring_jobs),
groups = [args.group],
types = [args.score_type],
force = args.force,
allow_missing_files = args.allow_missing_files,
write_compressed = args.write_compressed_score_files)
elif args.score_type in ['C', 'D']:
if args.score_type != 'Z':
model_ids = fs.model_ids(args.group) if args.score_type in ('A', 'B') else fs.t_model_ids(args.group)
tools.compute_scores(
args.algorithm,
args.extractor,
args.zt_norm,
indices = tools.indices(fs.t_model_ids(args.group), None if args.grid is None else args.grid.number_of_scoring_jobs),
indices = tools.indices(model_ids, None if args.grid is None else args.grid.number_of_scoring_jobs),
groups = [args.group],
types = [args.score_type],
force = args.force,
......
......@@ -13,11 +13,13 @@ from .FileSelector import FileSelector
from .extractor import read_features
from .. import utils
def _scores(algorithm, reader, model, probes, allow_missing_files):
def _scores(algorithm, reader, model, probe_objects, allow_missing_files):
"""Compute scores for the given model and a list of probes.
"""
# the file selector object
fs = FileSelector.instance()
# get probe files
probes = fs.get_paths(probe_objects, 'projected' if algorithm.performs_projection else 'extracted')
# the scores to be computed; initialized with NaN
scores = numpy.ones((1,len(probes)), numpy.float64) * numpy.nan
......@@ -135,10 +137,8 @@ def _scores_a(algorithm, reader, model_ids, group, compute_zt_norm, force, write
model = None
else:
model = algorithm.read_model(model_file)
# get the probe files
current_probe_files = fs.get_paths(current_probe_objects, 'projected' if algorithm.performs_projection else 'extracted')
# compute scores
a = _scores(algorithm, reader, model, current_probe_files, allow_missing_files)
a = _scores(algorithm, reader, model, current_probe_objects, allow_missing_files)
if compute_zt_norm:
# write A matrix only when you want to compute zt norm afterwards
......@@ -155,7 +155,6 @@ def _scores_b(algorithm, reader, model_ids, group, force, allow_missing_files):
# probe files:
z_probe_objects = fs.z_probe_objects(group)
z_probe_files = fs.get_paths(z_probe_objects, 'projected' if algorithm.performs_projection else 'extracted')
logger.info("- Scoring: computing score matrix B for group '%s'", group)
......@@ -171,7 +170,7 @@ def _scores_b(algorithm, reader, model_ids, group, force, allow_missing_files):
model = None
else:
model = algorithm.read_model(model_file)
b = _scores(algorithm, reader, model, z_probe_files, allow_missing_files)
b = _scores(algorithm, reader, model, z_probe_objects, allow_missing_files)
bob.io.base.save(b, score_file, True)
def _scores_c(algorithm, reader, t_model_ids, group, force, allow_missing_files):
......@@ -181,7 +180,6 @@ def _scores_c(algorithm, reader, t_model_ids, group, force, allow_missing_files)
# probe files:
probe_objects = fs.probe_objects(group)
probe_files = fs.get_paths(probe_objects, 'projected' if algorithm.performs_projection else 'extracted')
logger.info("- Scoring: computing score matrix C for group '%s'", group)
......@@ -197,7 +195,7 @@ def _scores_c(algorithm, reader, t_model_ids, group, force, allow_missing_files)
t_model = None
else:
t_model = algorithm.read_model(t_model_file)
c = _scores(algorithm, reader, t_model, probe_files, allow_missing_files)
c = _scores(algorithm, reader, t_model, probe_objects, allow_missing_files)
bob.io.base.save(c, score_file, True)
def _scores_d(algorithm, reader, t_model_ids, group, force, allow_missing_files):
......@@ -207,7 +205,6 @@ def _scores_d(algorithm, reader, t_model_ids, group, force, allow_missing_files)
# probe files:
z_probe_objects = fs.z_probe_objects(group)
z_probe_files = fs.get_paths(z_probe_objects, 'projected' if algorithm.performs_projection else 'extracted')
logger.info("- Scoring: computing score matrix D for group '%s'", group)
......@@ -227,7 +224,7 @@ def _scores_d(algorithm, reader, t_model_ids, group, force, allow_missing_files)
t_model = None
else:
t_model = algorithm.read_model(t_model_file)
d = _scores(algorithm, reader, t_model, z_probe_files, allow_missing_files)
d = _scores(algorithm, reader, t_model, z_probe_objects, allow_missing_files)
bob.io.base.save(d, score_file, True)
t_client_id = [fs.client_id(t_model_id, group, True)]
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment