diff --git a/bob/learn/tensorflow/script/predict_bio.py b/bob/learn/tensorflow/script/predict_bio.py index 2e06503eb9266e7fe93b3da83f762fee97d239b6..c0069023bd3c884aa0125e32d4b24c68475418a7 100644 --- a/bob/learn/tensorflow/script/predict_bio.py +++ b/bob/learn/tensorflow/script/predict_bio.py @@ -70,17 +70,17 @@ def save_predictions(pool, output_dir, key, pred_buffer): @click.option( '--database', '-d', - required=True, + default=None, cls=ResourceOption, entry_point_group='bob.bio.database', help='A bio database. Its original_directory must point to the correct ' - 'path.') + 'path. If `None` the `bio_predict_input_fn` must have infos about database') @click.option( '--biofiles', - required=True, + default=None, cls=ResourceOption, help='The list of the bio files. You can only provide this through config ' - 'files.') + 'files. If `None` `bio_predict_input_fn` must have infos about biofiles') @click.option( '--bio-predict-input-fn', required=True, @@ -90,7 +90,8 @@ def save_predictions(pool, output_dir, key, pred_buffer): '``input_fn = bio_predict_input_fn(generator, output_types, output_shapes)``' ' The inputs are documented in :any:`tf.data.Dataset.from_generator`' ' and the output should be a function with no arguments and is passed' - ' to :any:`tf.estimator.Estimator.predict`.') + ' to :any:`tf.estimator.Estimator.predict`. If no input argument is given' + ' bio_predict_input_fn is passed to :any:`tf.estimator.Estimator.predict`.') @click.option( '--output-dir', '-o', @@ -99,11 +100,13 @@ def save_predictions(pool, output_dir, key, pred_buffer): help='The directory to save the predictions.') @click.option( '--load-data', + default=None, cls=ResourceOption, entry_point_group='bob.learn.tensorflow.load_data', help='A callable with the signature of ' '``data = load_data(database, biofile)``. ' - ':any:`bob.bio.base.read_original_data` is used by default.') + ':any:`bob.bio.base.read_original_data` is used by default.' + ' If `None` `bio_predict_input_fn` must have infos about load_data') @click.option( '--hooks', cls=ResourceOption, @@ -184,7 +187,7 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, generator, output_types, output_shapes) # apply all kinds of transformations here, process the data # even further if you want. - dataset = dataset.prefetch(1) + dataset = dataset.prefetch(2*10**3) dataset = dataset.batch(10**3) images, labels, keys = dataset.make_one_shot_iterator().get_next() @@ -192,34 +195,40 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, return input_fn """ log_parameters(logger, ignore=('biofiles',)) - logger.debug("len(biofiles): %d", len(biofiles)) - - assert len(biofiles), "biofiles are empty!" + generator=None + + if not biofiles is None: + logger.debug("len(biofiles): %d", len(biofiles)) + assert len(biofiles), "biofiles are empty!" + if array > 1: + start, end = indices(biofiles, array) + biofiles = biofiles[start:end] + + if not biofiles is None and not database is None and not load_data is None: - if array > 1: - start, end = indices(biofiles, array) - biofiles = biofiles[start:end] + # filter the existing files + paths = [ + make_output_path(output_dir, f.make_path("", "")) for f in biofiles + ] + indexes = non_existing_files(paths, force) + biofiles = [biofiles[i] for i in indexes] - # filter the existing files - paths = [ - make_output_path(output_dir, f.make_path("", "")) for f in biofiles - ] - indexes = non_existing_files(paths, force) - biofiles = [biofiles[i] for i in indexes] - - if len(biofiles) == 0: - logger.warning( - "The biofiles are empty after checking for existing files.") - return + if len(biofiles) == 0: + logger.warning( + "The biofiles are empty after checking for existing files.") + return - generator = BioGenerator( - database, - biofiles, - load_data=load_data, - multiple_samples=multiple_samples) + generator = BioGenerator( + database, + biofiles, + load_data=load_data, + multiple_samples=multiple_samples) - predict_input_fn = bio_predict_input_fn(generator, generator.output_types, - generator.output_shapes) + predict_input_fn = bio_predict_input_fn(generator, generator.output_types, + generator.output_shapes) + else: + predict_input_fn = bio_predict_input_fn + if checkpoint_path: if os.path.isdir(checkpoint_path): @@ -236,8 +245,11 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, checkpoint_path=checkpoint_path, ) - logger.info("Saving the predictions of %d files in %s", len(generator), - output_dir) + if not generator is None: + logger.info("Saving the predictions of %d files in %s", len(generator), + output_dir) + else: + logger.info("Saving the predictions files in %s", output_dir) pool = Pool() try: