Skip to content
Snippets Groups Projects
Commit a6b5be1d authored by Manuel Günther's avatar Manuel Günther
Browse files

Small restructuring and less repetitions

parent ec51cfc3
No related branches found
No related tags found
1 merge request!91Small restructuring and less repetitions
Pipeline #
...@@ -206,7 +206,7 @@ def add_jobs(args, submitter): ...@@ -206,7 +206,7 @@ def add_jobs(args, submitter):
dependencies = score_deps[group]) dependencies = score_deps[group])
concat_deps[group].extend([job_ids['score-%s-B'%group], job_ids['score-%s-C'%group], job_ids['score-%s-D'%group], job_ids['score-%s-Z'%group]]) concat_deps[group].extend([job_ids['score-%s-B'%group], job_ids['score-%s-C'%group], job_ids['score-%s-D'%group], job_ids['score-%s-Z'%group]])
else: else:
concat_deps[group] = [] concat_deps[group] = deps[:]
# concatenate results # concatenate results
if not args.skip_concatenation: if not args.skip_concatenation:
...@@ -309,48 +309,26 @@ def execute(args): ...@@ -309,48 +309,26 @@ def execute(args):
# enroll the models # enroll the models
elif args.sub_task == 'enroll': elif args.sub_task == 'enroll':
if args.model_type == 'N': model_ids = fs.model_ids(args.group) if args.model_type == 'N' else fs.t_model_ids(args.group)
tools.enroll( tools.enroll(
args.algorithm, args.algorithm,
args.extractor, args.extractor,
args.zt_norm, args.zt_norm,
indices = tools.indices(fs.model_ids(args.group), None if args.grid is None else args.grid.number_of_enrollment_jobs), indices = tools.indices(model_ids, None if args.grid is None else args.grid.number_of_enrollment_jobs),
groups = [args.group], groups = [args.group],
types = ['N'], types = [args.model_type],
allow_missing_files = args.allow_missing_files, allow_missing_files = args.allow_missing_files,
force = args.force) force = args.force)
else:
tools.enroll(
args.algorithm,
args.extractor,
args.zt_norm,
indices = tools.indices(fs.t_model_ids(args.group), None if args.grid is None else args.grid.number_of_enrollment_jobs),
groups = [args.group],
types = ['T'],
allow_missing_files = args.allow_missing_files,
force = args.force)
# compute scores # compute scores
elif args.sub_task == 'compute-scores': elif args.sub_task == 'compute-scores':
if args.score_type in ['A', 'B']: if args.score_type != 'Z':
tools.compute_scores( model_ids = fs.model_ids(args.group) if args.score_type in ('A', 'B') else fs.t_model_ids(args.group)
args.algorithm,
args.extractor,
args.zt_norm,
indices = tools.indices(fs.model_ids(args.group), None if args.grid is None else args.grid.number_of_scoring_jobs),
groups = [args.group],
types = [args.score_type],
force = args.force,
allow_missing_files = args.allow_missing_files,
write_compressed = args.write_compressed_score_files)
elif args.score_type in ['C', 'D']:
tools.compute_scores( tools.compute_scores(
args.algorithm, args.algorithm,
args.extractor, args.extractor,
args.zt_norm, args.zt_norm,
indices = tools.indices(fs.t_model_ids(args.group), None if args.grid is None else args.grid.number_of_scoring_jobs), indices = tools.indices(model_ids, None if args.grid is None else args.grid.number_of_scoring_jobs),
groups = [args.group], groups = [args.group],
types = [args.score_type], types = [args.score_type],
force = args.force, force = args.force,
......
...@@ -13,11 +13,13 @@ from .FileSelector import FileSelector ...@@ -13,11 +13,13 @@ from .FileSelector import FileSelector
from .extractor import read_features from .extractor import read_features
from .. import utils from .. import utils
def _scores(algorithm, reader, model, probes, allow_missing_files): def _scores(algorithm, reader, model, probe_objects, allow_missing_files):
"""Compute scores for the given model and a list of probes. """Compute scores for the given model and a list of probes.
""" """
# the file selector object # the file selector object
fs = FileSelector.instance() fs = FileSelector.instance()
# get probe files
probes = fs.get_paths(probe_objects, 'projected' if algorithm.performs_projection else 'extracted')
# the scores to be computed; initialized with NaN # the scores to be computed; initialized with NaN
scores = numpy.ones((1,len(probes)), numpy.float64) * numpy.nan scores = numpy.ones((1,len(probes)), numpy.float64) * numpy.nan
...@@ -135,10 +137,8 @@ def _scores_a(algorithm, reader, model_ids, group, compute_zt_norm, force, write ...@@ -135,10 +137,8 @@ def _scores_a(algorithm, reader, model_ids, group, compute_zt_norm, force, write
model = None model = None
else: else:
model = algorithm.read_model(model_file) model = algorithm.read_model(model_file)
# get the probe files
current_probe_files = fs.get_paths(current_probe_objects, 'projected' if algorithm.performs_projection else 'extracted')
# compute scores # compute scores
a = _scores(algorithm, reader, model, current_probe_files, allow_missing_files) a = _scores(algorithm, reader, model, current_probe_objects, allow_missing_files)
if compute_zt_norm: if compute_zt_norm:
# write A matrix only when you want to compute zt norm afterwards # write A matrix only when you want to compute zt norm afterwards
...@@ -155,7 +155,6 @@ def _scores_b(algorithm, reader, model_ids, group, force, allow_missing_files): ...@@ -155,7 +155,6 @@ def _scores_b(algorithm, reader, model_ids, group, force, allow_missing_files):
# probe files: # probe files:
z_probe_objects = fs.z_probe_objects(group) z_probe_objects = fs.z_probe_objects(group)
z_probe_files = fs.get_paths(z_probe_objects, 'projected' if algorithm.performs_projection else 'extracted')
logger.info("- Scoring: computing score matrix B for group '%s'", group) logger.info("- Scoring: computing score matrix B for group '%s'", group)
...@@ -171,7 +170,7 @@ def _scores_b(algorithm, reader, model_ids, group, force, allow_missing_files): ...@@ -171,7 +170,7 @@ def _scores_b(algorithm, reader, model_ids, group, force, allow_missing_files):
model = None model = None
else: else:
model = algorithm.read_model(model_file) model = algorithm.read_model(model_file)
b = _scores(algorithm, reader, model, z_probe_files, allow_missing_files) b = _scores(algorithm, reader, model, z_probe_objects, allow_missing_files)
bob.io.base.save(b, score_file, True) bob.io.base.save(b, score_file, True)
def _scores_c(algorithm, reader, t_model_ids, group, force, allow_missing_files): def _scores_c(algorithm, reader, t_model_ids, group, force, allow_missing_files):
...@@ -181,7 +180,6 @@ def _scores_c(algorithm, reader, t_model_ids, group, force, allow_missing_files) ...@@ -181,7 +180,6 @@ def _scores_c(algorithm, reader, t_model_ids, group, force, allow_missing_files)
# probe files: # probe files:
probe_objects = fs.probe_objects(group) probe_objects = fs.probe_objects(group)
probe_files = fs.get_paths(probe_objects, 'projected' if algorithm.performs_projection else 'extracted')
logger.info("- Scoring: computing score matrix C for group '%s'", group) logger.info("- Scoring: computing score matrix C for group '%s'", group)
...@@ -197,7 +195,7 @@ def _scores_c(algorithm, reader, t_model_ids, group, force, allow_missing_files) ...@@ -197,7 +195,7 @@ def _scores_c(algorithm, reader, t_model_ids, group, force, allow_missing_files)
t_model = None t_model = None
else: else:
t_model = algorithm.read_model(t_model_file) t_model = algorithm.read_model(t_model_file)
c = _scores(algorithm, reader, t_model, probe_files, allow_missing_files) c = _scores(algorithm, reader, t_model, probe_objects, allow_missing_files)
bob.io.base.save(c, score_file, True) bob.io.base.save(c, score_file, True)
def _scores_d(algorithm, reader, t_model_ids, group, force, allow_missing_files): def _scores_d(algorithm, reader, t_model_ids, group, force, allow_missing_files):
...@@ -207,7 +205,6 @@ def _scores_d(algorithm, reader, t_model_ids, group, force, allow_missing_files) ...@@ -207,7 +205,6 @@ def _scores_d(algorithm, reader, t_model_ids, group, force, allow_missing_files)
# probe files: # probe files:
z_probe_objects = fs.z_probe_objects(group) z_probe_objects = fs.z_probe_objects(group)
z_probe_files = fs.get_paths(z_probe_objects, 'projected' if algorithm.performs_projection else 'extracted')
logger.info("- Scoring: computing score matrix D for group '%s'", group) logger.info("- Scoring: computing score matrix D for group '%s'", group)
...@@ -227,7 +224,7 @@ def _scores_d(algorithm, reader, t_model_ids, group, force, allow_missing_files) ...@@ -227,7 +224,7 @@ def _scores_d(algorithm, reader, t_model_ids, group, force, allow_missing_files)
t_model = None t_model = None
else: else:
t_model = algorithm.read_model(t_model_file) t_model = algorithm.read_model(t_model_file)
d = _scores(algorithm, reader, t_model, z_probe_files, allow_missing_files) d = _scores(algorithm, reader, t_model, z_probe_objects, allow_missing_files)
bob.io.base.save(d, score_file, True) bob.io.base.save(d, score_file, True)
t_client_id = [fs.client_id(t_model_id, group, True)] t_client_id = [fs.client_id(t_model_id, group, True)]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment