Commit dc19fef7 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Fix repeats. Skip existing ones

parent eaae3554
......@@ -35,6 +35,8 @@ Options:
script is run in the SGE grid. You
should use this option with
``jman submit -t N``.
-f, --force If provided, it will overwrite the existing
predictions.
-v, --verbose Increases the output verbosity level
The configuration files should have the following objects totally:
......@@ -76,8 +78,13 @@ from bob.core.log import setup, set_verbosity_level
logger = setup(__name__)
def make_output_path(output_dir, key):
return os.path.join(output_dir, key + '.hdf5')
def bio_generator(database, preprocessor, groups, number_of_parallel_jobs,
biofile_to_label, multiple_samples=False):
biofile_to_label, output_dir, multiple_samples=False,
force=False):
biofiles = list(database.all_files(groups))
if number_of_parallel_jobs > 1:
start, end = indices(biofiles, number_of_parallel_jobs)
......@@ -95,10 +102,13 @@ def bio_generator(database, preprocessor, groups, number_of_parallel_jobs,
def generator():
for f, label, key in six.moves.zip(biofiles, labels, keys):
outpath = make_output_path(output_dir, key)
if not force and os.path.isfile(outpath):
continue
data = load_data(f, preprocessor, database)
if multiple_samples:
label = tf_repeat([label], len(data))
key = tf_repeat([key], len(data))
label = [label for _ in range(len(data))]
key = [key for _ in range(len(data))]
yield (data, label, key)
# load one data to get its type and shape
......@@ -119,7 +129,7 @@ def bio_generator(database, preprocessor, groups, number_of_parallel_jobs,
def save_predictions(pool, output_dir, key, pred_buffer):
outpath = os.path.join(output_dir, key + '.hdf5')
outpath = make_output_path(output_dir, key)
create_directories_safe(os.path.dirname(outpath))
pool.apply_async(save, (np.mean(pred_buffer[key], axis=0), outpath))
......@@ -145,6 +155,8 @@ def main(argv=None):
config, 'multiple_samples', args, defaults)
number_of_parallel_jobs = get_from_config_or_commandline(
config, 'number_of_parallel_jobs', args, defaults)
force = get_from_config_or_commandline(
config, 'force', args, defaults)
hooks = getattr(config, 'hooks', None)
# Sets-up logging
......@@ -162,7 +174,7 @@ def main(argv=None):
generator, output_types, output_shapes = bio_generator(
database, preprocessor, groups, number_of_parallel_jobs,
biofile_to_label, multiple_samples)
biofile_to_label, output_dir, multiple_samples, force)
predict_input_fn = bio_predict_input_fn(generator,
output_types, output_shapes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment