Skip to content
Snippets Groups Projects
Commit 9d903ca0 authored by Saeed SARFJOO
Browse files

integrate the structure of train, eval, and predict

parent adff932a
Branches master
No related tags found
1 merge request: !62 "Integrate the structure of train, eval, and predict"
Pipeline #
...@@ -70,17 +70,17 @@ def save_predictions(pool, output_dir, key, pred_buffer): ...@@ -70,17 +70,17 @@ def save_predictions(pool, output_dir, key, pred_buffer):
@click.option( @click.option(
'--database', '--database',
'-d', '-d',
required=True, default=None,
cls=ResourceOption, cls=ResourceOption,
entry_point_group='bob.bio.database', entry_point_group='bob.bio.database',
help='A bio database. Its original_directory must point to the correct ' help='A bio database. Its original_directory must point to the correct '
'path.') 'path. If `None` the `bio_predict_input_fn` must have infos about database')
@click.option( @click.option(
'--biofiles', '--biofiles',
required=True, default=None,
cls=ResourceOption, cls=ResourceOption,
help='The list of the bio files. You can only provide this through config ' help='The list of the bio files. You can only provide this through config '
'files.') 'files. If `None` `bio_predict_input_fn` must have infos about biofiles')
@click.option( @click.option(
'--bio-predict-input-fn', '--bio-predict-input-fn',
required=True, required=True,
...@@ -90,7 +90,8 @@ def save_predictions(pool, output_dir, key, pred_buffer): ...@@ -90,7 +90,8 @@ def save_predictions(pool, output_dir, key, pred_buffer):
'``input_fn = bio_predict_input_fn(generator, output_types, output_shapes)``' '``input_fn = bio_predict_input_fn(generator, output_types, output_shapes)``'
' The inputs are documented in :any:`tf.data.Dataset.from_generator`' ' The inputs are documented in :any:`tf.data.Dataset.from_generator`'
' and the output should be a function with no arguments and is passed' ' and the output should be a function with no arguments and is passed'
' to :any:`tf.estimator.Estimator.predict`.') ' to :any:`tf.estimator.Estimator.predict`. If no input argument is given'
' bio_predict_input_fn is passed to :any:`tf.estimator.Estimator.predict`.')
@click.option( @click.option(
'--output-dir', '--output-dir',
'-o', '-o',
...@@ -99,11 +100,13 @@ def save_predictions(pool, output_dir, key, pred_buffer): ...@@ -99,11 +100,13 @@ def save_predictions(pool, output_dir, key, pred_buffer):
help='The directory to save the predictions.') help='The directory to save the predictions.')
@click.option( @click.option(
'--load-data', '--load-data',
default=None,
cls=ResourceOption, cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.load_data', entry_point_group='bob.learn.tensorflow.load_data',
help='A callable with the signature of ' help='A callable with the signature of '
'``data = load_data(database, biofile)``. ' '``data = load_data(database, biofile)``. '
':any:`bob.bio.base.read_original_data` is used by default.') ':any:`bob.bio.base.read_original_data` is used by default.'
' If `None` `bio_predict_input_fn` must have infos about load_data')
@click.option( @click.option(
'--hooks', '--hooks',
cls=ResourceOption, cls=ResourceOption,
...@@ -184,7 +187,7 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, ...@@ -184,7 +187,7 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
generator, output_types, output_shapes) generator, output_types, output_shapes)
# apply all kinds of transformations here, process the data # apply all kinds of transformations here, process the data
# even further if you want. # even further if you want.
dataset = dataset.prefetch(1) dataset = dataset.prefetch(2*10**3)
dataset = dataset.batch(10**3) dataset = dataset.batch(10**3)
images, labels, keys = dataset.make_one_shot_iterator().get_next() images, labels, keys = dataset.make_one_shot_iterator().get_next()
...@@ -192,34 +195,40 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, ...@@ -192,34 +195,40 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
return input_fn return input_fn
""" """
log_parameters(logger, ignore=('biofiles',)) log_parameters(logger, ignore=('biofiles',))
logger.debug("len(biofiles): %d", len(biofiles)) generator=None
assert len(biofiles), "biofiles are empty!" if not biofiles is None:
logger.debug("len(biofiles): %d", len(biofiles))
assert len(biofiles), "biofiles are empty!"
if array > 1:
start, end = indices(biofiles, array)
biofiles = biofiles[start:end]
if not biofiles is None and not database is None and not load_data is None:
if array > 1: # filter the existing files
start, end = indices(biofiles, array) paths = [
biofiles = biofiles[start:end] make_output_path(output_dir, f.make_path("", "")) for f in biofiles
]
indexes = non_existing_files(paths, force)
biofiles = [biofiles[i] for i in indexes]
# filter the existing files if len(biofiles) == 0:
paths = [ logger.warning(
make_output_path(output_dir, f.make_path("", "")) for f in biofiles "The biofiles are empty after checking for existing files.")
] return
indexes = non_existing_files(paths, force)
biofiles = [biofiles[i] for i in indexes]
if len(biofiles) == 0:
logger.warning(
"The biofiles are empty after checking for existing files.")
return
generator = BioGenerator( generator = BioGenerator(
database, database,
biofiles, biofiles,
load_data=load_data, load_data=load_data,
multiple_samples=multiple_samples) multiple_samples=multiple_samples)
predict_input_fn = bio_predict_input_fn(generator, generator.output_types, predict_input_fn = bio_predict_input_fn(generator, generator.output_types,
generator.output_shapes) generator.output_shapes)
else:
predict_input_fn = bio_predict_input_fn
if checkpoint_path: if checkpoint_path:
if os.path.isdir(checkpoint_path): if os.path.isdir(checkpoint_path):
...@@ -236,8 +245,11 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, ...@@ -236,8 +245,11 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
checkpoint_path=checkpoint_path, checkpoint_path=checkpoint_path,
) )
logger.info("Saving the predictions of %d files in %s", len(generator), if not generator is None:
output_dir) logger.info("Saving the predictions of %d files in %s", len(generator),
output_dir)
else:
logger.info("Saving the predictions files in %s", output_dir)
pool = Pool() pool = Pool()
try: try:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment