Skip to content
Snippets Groups Projects

Integrate the structure of train, eval, and predict

Closed Saeed SARFJOO requested to merge integrate_structures into master
1 unresolved thread
1 file
+ 45
33
Compare changes
  • Side-by-side
  • Inline
@@ -70,17 +70,17 @@ def save_predictions(pool, output_dir, key, pred_buffer):
@click.option(
'--database',
'-d',
required=True,
default=None,
cls=ResourceOption,
entry_point_group='bob.bio.database',
help='A bio database. Its original_directory must point to the correct '
'path.')
'path. If `None` the `bio_predict_input_fn` must have infos about database')
@click.option(
'--biofiles',
required=True,
default=None,
cls=ResourceOption,
help='The list of the bio files. You can only provide this through config '
'files.')
'files. If `None` `bio_predict_input_fn` must have infos about biofiles')
@click.option(
'--bio-predict-input-fn',
required=True,
@@ -90,7 +90,8 @@ def save_predictions(pool, output_dir, key, pred_buffer):
'``input_fn = bio_predict_input_fn(generator, output_types, output_shapes)``'
' The inputs are documented in :any:`tf.data.Dataset.from_generator`'
' and the output should be a function with no arguments and is passed'
' to :any:`tf.estimator.Estimator.predict`.')
' to :any:`tf.estimator.Estimator.predict`. If no input argument is given'
' bio_predict_input_fn is passed to :any:`tf.estimator.Estimator.predict`.')
@click.option(
'--output-dir',
'-o',
@@ -99,11 +100,13 @@ def save_predictions(pool, output_dir, key, pred_buffer):
help='The directory to save the predictions.')
@click.option(
'--load-data',
default=None,
cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.load_data',
help='A callable with the signature of '
'``data = load_data(database, biofile)``. '
':any:`bob.bio.base.read_original_data` is used by default.')
':any:`bob.bio.base.read_original_data` is used by default.'
' If `None` `bio_predict_input_fn` must have infos about load_data')
@click.option(
'--hooks',
cls=ResourceOption,
@@ -184,7 +187,7 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
generator, output_types, output_shapes)
# apply all kinds of transformations here, process the data
# even further if you want.
dataset = dataset.prefetch(1)
dataset = dataset.prefetch(2*10**3)
dataset = dataset.batch(10**3)
images, labels, keys = dataset.make_one_shot_iterator().get_next()
@@ -192,34 +195,40 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
return input_fn
"""
log_parameters(logger, ignore=('biofiles',))
logger.debug("len(biofiles): %d", len(biofiles))
assert len(biofiles), "biofiles are empty!"
generator=None
if not biofiles is None:
logger.debug("len(biofiles): %d", len(biofiles))
assert len(biofiles), "biofiles are empty!"
if array > 1:
start, end = indices(biofiles, array)
biofiles = biofiles[start:end]
if not biofiles is None and not database is None and not load_data is None:
if array > 1:
start, end = indices(biofiles, array)
biofiles = biofiles[start:end]
# filter the existing files
paths = [
make_output_path(output_dir, f.make_path("", "")) for f in biofiles
]
indexes = non_existing_files(paths, force)
biofiles = [biofiles[i] for i in indexes]
# filter the existing files
paths = [
make_output_path(output_dir, f.make_path("", "")) for f in biofiles
]
indexes = non_existing_files(paths, force)
biofiles = [biofiles[i] for i in indexes]
if len(biofiles) == 0:
logger.warning(
"The biofiles are empty after checking for existing files.")
return
if len(biofiles) == 0:
logger.warning(
"The biofiles are empty after checking for existing files.")
return
generator = BioGenerator(
database,
biofiles,
load_data=load_data,
multiple_samples=multiple_samples)
generator = BioGenerator(
database,
biofiles,
load_data=load_data,
multiple_samples=multiple_samples)
predict_input_fn = bio_predict_input_fn(generator, generator.output_types,
generator.output_shapes)
predict_input_fn = bio_predict_input_fn(generator, generator.output_types,
generator.output_shapes)
else:
predict_input_fn = bio_predict_input_fn
if checkpoint_path:
if os.path.isdir(checkpoint_path):
@@ -236,8 +245,11 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
checkpoint_path=checkpoint_path,
)
logger.info("Saving the predictions of %d files in %s", len(generator),
output_dir)
if not generator is None:
logger.info("Saving the predictions of %d files in %s", len(generator),
output_dir)
else:
logger.info("Saving the predictions files in %s", output_dir)
pool = Pool()
try:
Loading