diff --git a/bob/learn/tensorflow/script/predict_bio.py b/bob/learn/tensorflow/script/predict_bio.py index 29f8343af22f13455489ce1626c1de255e18c5e3..8d8f8a6fdda0eeb9b07898968845fa6ee00bcba0 100644 --- a/bob/learn/tensorflow/script/predict_bio.py +++ b/bob/learn/tensorflow/script/predict_bio.py @@ -9,8 +9,11 @@ import sys import logging import click from bob.extension.scripts.click_helper import ( - verbosity_option, ConfigCommand, ResourceOption, log_parameters) -from multiprocessing import Pool + verbosity_option, + ConfigCommand, + ResourceOption, + log_parameters, +) from collections import defaultdict import numpy as np import tensorflow as tf @@ -18,6 +21,7 @@ from bob.io.base import create_directories_safe from bob.bio.base.utils import save from bob.bio.base.tools.grid import indices from bob.learn.tensorflow.dataset.bio import BioGenerator + try: import bob.bio.video except ModuleNotFoundError: @@ -42,7 +46,7 @@ def make_output_path(output_dir, key): str The path for the provided key. """ - return os.path.join(output_dir, key + '.hdf5') + return os.path.join(output_dir, key + ".hdf5") def non_existing_files(paths, force=False): @@ -69,110 +73,133 @@ def save_predictions(output_dir, key, pred_buffer, video_container): save(data, outpath) -@click.command( - entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand) +@click.command(entry_point_group="bob.learn.tensorflow.config", cls=ConfigCommand) @click.option( - '--estimator', - '-e', + "--estimator", + "-e", required=True, cls=ResourceOption, - entry_point_group='bob.learn.tensorflow.estimator', - help='The estimator that will be evaluated.') + entry_point_group="bob.learn.tensorflow.estimator", + help="The estimator that will be evaluated.", +) @click.option( - '--database', - '-d', + "--database", + "-d", required=True, cls=ResourceOption, - entry_point_group='bob.bio.database', - help='A bio database. Its original_directory must point to the correct ' - 'path.') + entry_point_group="bob.bio.database", + help="A bio database. Its original_directory must point to the correct " "path.", +) @click.option( - '--biofiles', + "--biofiles", required=True, cls=ResourceOption, - help='The list of the bio files. You can only provide this through config ' - 'files.') + help="The list of the bio files. You can only provide this through config " + "files.", +) @click.option( - '--bio-predict-input-fn', + "--bio-predict-input-fn", required=True, cls=ResourceOption, - entry_point_group='bob.learn.tensorflow.biogenerator_input', - help='A callable with the signature of ' - '`input_fn = bio_predict_input_fn(generator, output_types, output_shapes)`' - ' The inputs are documented in :any:`tf.data.Dataset.from_generator`' - ' and the output should be a function with no arguments and is passed' - ' to :any:`tf.estimator.Estimator.predict`.') + entry_point_group="bob.learn.tensorflow.biogenerator_input", + help="A callable with the signature of " + "`input_fn = bio_predict_input_fn(generator, output_types, output_shapes)`" + " The inputs are documented in :any:`tf.data.Dataset.from_generator`" + " and the output should be a function with no arguments and is passed" + " to :any:`tf.estimator.Estimator.predict`.", +) @click.option( - '--output-dir', - '-o', + "--output-dir", + "-o", required=True, cls=ResourceOption, - help='The directory to save the predictions.') + help="The directory to save the predictions.", +) @click.option( - '--load-data', + "--load-data", cls=ResourceOption, - entry_point_group='bob.learn.tensorflow.load_data', - help='A callable with the signature of ' - '``data = load_data(database, biofile)``. ' - ':any:`bob.bio.base.read_original_data` is used by default.') + entry_point_group="bob.learn.tensorflow.load_data", + help="A callable with the signature of " + "``data = load_data(database, biofile)``. " + ":any:`bob.bio.base.read_original_data` is used by default.", +) @click.option( - '--hooks', + "--hooks", cls=ResourceOption, multiple=True, - entry_point_group='bob.learn.tensorflow.hook', - help='List of SessionRunHook subclass instances.') + entry_point_group="bob.learn.tensorflow.hook", + help="List of SessionRunHook subclass instances.", +) @click.option( - '--predict-keys', - '-k', + "--predict-keys", + "-k", multiple=True, default=None, cls=ResourceOption, - help='List of `str`, name of the keys to predict. It is used if the ' - '`EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used ' - 'then rest of the predictions will be filtered from the dictionary. ' - 'If `None`, returns all.') + help="List of `str`, name of the keys to predict. It is used if the " + "`EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used " + "then rest of the predictions will be filtered from the dictionary. " + "If `None`, returns all.", +) @click.option( - '--checkpoint-path', - '-c', + "--checkpoint-path", + "-c", cls=ResourceOption, - help='Path of a specific checkpoint to predict. If `None`, the ' - 'latest checkpoint in `model_dir` is used. This can also ' + help="Path of a specific checkpoint to predict. If `None`, the " + "latest checkpoint in `model_dir` is used. This can also " 'be a folder which contains a "checkpoint" file where the ' - 'latest checkpoint from inside this file will be used as ' - 'checkpoint_path.') + "latest checkpoint from inside this file will be used as " + "checkpoint_path.", +) @click.option( - '--multiple-samples', - '-m', + "--multiple-samples", + "-m", is_flag=True, cls=ResourceOption, - help='If provided, it assumes that the db interface returns ' - 'several samples from a biofile. This option can be used ' - 'when you are working with videos.') + help="If provided, it assumes that the db interface returns " + "several samples from a biofile. This option can be used " + "when you are working with videos.", +) @click.option( - '--array', - '-t', + "--array", + "-t", type=click.INT, default=1, cls=ResourceOption, - help='Use this option alongside gridtk to submit this script as ' - 'an array job.') + help="Use this option alongside gridtk to submit this script as " "an array job.", +) @click.option( - '--force', - '-f', + "--force", + "-f", is_flag=True, cls=ResourceOption, - help='Whether to overwrite existing predictions.') + help="Whether to overwrite existing predictions.", +) @click.option( - '--video-container', - '-vc', + "--video-container", + "-vc", is_flag=True, cls=ResourceOption, - help='If provided, the predictions will be written in FrameContainers from' - ' bob.bio.video. You need to install bob.bio.video as well.') + help="If provided, the predictions will be written in FrameContainers from" + " bob.bio.video. You need to install bob.bio.video as well.", +) @verbosity_option(cls=ResourceOption) -def predict_bio(estimator, database, biofiles, bio_predict_input_fn, - output_dir, load_data, hooks, predict_keys, checkpoint_path, - multiple_samples, array, force, video_container, **kwargs): +def predict_bio( + estimator, + database, + biofiles, + bio_predict_input_fn, + output_dir, + load_data, + hooks, + predict_keys, + checkpoint_path, + multiple_samples, + array, + force, + video_container, + **kwargs +): """Saves predictions or embeddings of tf.estimators. This script works with bob.bio.base databases. This script works with @@ -209,7 +236,7 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, return {'data': images, 'key': keys}, labels return input_fn """ - log_parameters(logger, ignore=('biofiles', )) + log_parameters(logger, ignore=("biofiles",)) logger.debug("len(biofiles): %d", len(biofiles)) assert len(biofiles), "biofiles are empty!" @@ -219,98 +246,125 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, biofiles = biofiles[start:end] # filter the existing files - paths = [ - make_output_path(output_dir, f.make_path("", "")) for f in biofiles - ] + paths = [make_output_path(output_dir, f.make_path("", "")) for f in biofiles] indexes = non_existing_files(paths, force) biofiles = [biofiles[i] for i in indexes] if len(biofiles) == 0: - logger.warning( - "The biofiles are empty after checking for existing files.") + logger.warning("The biofiles are empty after checking for existing files.") return generator = BioGenerator( - database, - biofiles, - load_data=load_data, - multiple_samples=multiple_samples) + database, biofiles, load_data=load_data, multiple_samples=multiple_samples + ) - predict_input_fn = bio_predict_input_fn(generator, generator.output_types, - generator.output_shapes) + predict_input_fn = bio_predict_input_fn( + generator, generator.output_types, generator.output_shapes + ) - logger.info("Saving the predictions of %d files in %s", len(generator), - output_dir) + logger.info("Saving the predictions of %d files in %s", len(generator), output_dir) generic_predict( - estimator, predict_input_fn, output_dir, predict_keys=predict_keys, - checkpoint_path=checkpoint_path, hooks=hooks, - video_container=video_container) + estimator=estimator, + predict_input_fn=predict_input_fn, + output_dir=output_dir, + predict_keys=predict_keys, + checkpoint_path=checkpoint_path, + hooks=hooks, + video_container=video_container, + ) -@click.command( - entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand) +@click.command(entry_point_group="bob.learn.tensorflow.config", cls=ConfigCommand) @click.option( - '--estimator', - '-e', + "--estimator", + "-e", required=True, cls=ResourceOption, - entry_point_group='bob.learn.tensorflow.estimator', - help='The estimator that will be evaluated.') + entry_point_group="bob.learn.tensorflow.estimator", + help="The estimator that will be evaluated.", +) @click.option( - '--predict-input-fn', + "--predict-input-fn", required=True, cls=ResourceOption, - entry_point_group='bob.learn.tensorflow.input_fn', - help='A callable with no arguments which will be used in estimator.predict.') + entry_point_group="bob.learn.tensorflow.input_fn", + help="A callable with no arguments which will be used in estimator.predict.", +) @click.option( - '--output-dir', - '-o', + "--output-dir", + "-o", required=True, cls=ResourceOption, - help='The directory to save the predictions.') + help="The directory to save the predictions.", +) @click.option( - '--predict-keys', - '-k', + "--predict-keys", + "-k", multiple=True, default=None, cls=ResourceOption, - help='List of `str`, name of the keys to predict. It is used if the ' - '`EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used ' - 'then rest of the predictions will be filtered from the dictionary. ' - 'If `None`, returns all.') + help="List of `str`, name of the keys to predict. It is used if the " + "`EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used " + "then rest of the predictions will be filtered from the dictionary. " + "If `None`, returns all.", +) @click.option( - '--checkpoint-path', - '-c', + "--checkpoint-path", + "-c", cls=ResourceOption, - help='Path of a specific checkpoint to predict. If `None`, the ' - 'latest checkpoint in `model_dir` is used. This can also ' + help="Path of a specific checkpoint to predict. If `None`, the " + "latest checkpoint in `model_dir` is used. This can also " 'be a folder which contains a "checkpoint" file where the ' - 'latest checkpoint from inside this file will be used as ' - 'checkpoint_path.') + "latest checkpoint from inside this file will be used as " + "checkpoint_path.", +) @click.option( - '--hooks', + "--hooks", cls=ResourceOption, multiple=True, - entry_point_group='bob.learn.tensorflow.hook', - help='List of SessionRunHook subclass instances.') + entry_point_group="bob.learn.tensorflow.hook", + help="List of SessionRunHook subclass instances.", +) @click.option( - '--video-container', - '-vc', + "--video-container", + "-vc", is_flag=True, cls=ResourceOption, - help='If provided, the predictions will be written in FrameContainers from' - ' bob.bio.video. You need to install bob.bio.video as well.') + help="If provided, the predictions will be written in FrameContainers from" + " bob.bio.video. You need to install bob.bio.video as well.", +) @verbosity_option(cls=ResourceOption) -def predict(estimator, predict_input_fn, output_dir, predict_keys, - checkpoint_path, hooks, video_container, **kwargs): +def predict( + estimator, + predict_input_fn, + output_dir, + predict_keys, + checkpoint_path, + hooks, + video_container, + **kwargs +): generic_predict( - estimator, predict_input_fn, output_dir, predict_keys, - checkpoint_path, hooks, video_container) + estimator=estimator, + predict_input_fn=predict_input_fn, + output_dir=output_dir, + predict_keys=predict_keys, + checkpoint_path=checkpoint_path, + hooks=hooks, + video_container=video_container, + ) -def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None, - checkpoint_path=None, hooks=None, video_container=False): +def generic_predict( + estimator, + predict_input_fn, + output_dir, + predict_keys=None, + checkpoint_path=None, + hooks=None, + video_container=False, +): # if the checkpoint_path is a directory, pick the latest checkpoint from # that directory if checkpoint_path: @@ -333,16 +387,21 @@ def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None, import bob.bio.video except ModuleNotFoundError: raise click.ClickException( - 'Could not import bob.bio.video. Have you installed it?') + "Could not import bob.bio.video. Have you installed it?" + ) pred_buffer = defaultdict(list) for i, pred in enumerate(predictions): - key = pred['key'] + key = pred["key"] # key is in bytes format in Python 3 - if sys.version_info >= (3, ): - key = key.decode(errors='replace') - prob = pred.get('probabilities', pred.get( - 'embeddings', pred.get('predictions'))) + if sys.version_info >= (3,): + key = key.decode(errors="replace") + if predict_keys: + prob = pred[predict_keys[0]] + else: + prob = pred.get( + "probabilities", pred.get("embeddings", pred.get("predictions")) + ) assert prob is not None pred_buffer[key].append(prob) if i == 0: @@ -350,13 +409,17 @@ def generic_predict(estimator, predict_input_fn, output_dir, predict_keys=None, if last_key == key: continue else: - save_predictions( - output_dir, last_key, pred_buffer, video_container) + save_predictions(output_dir, last_key, pred_buffer, video_container) + # delete saved data so we don't run out of RAM + del pred_buffer[last_key] + # start saving this new key last_key = key + try: key # save the final returned key as well: save_predictions(output_dir, key, pred_buffer, video_container) except UnboundLocalError: # if the input_fn was empty and hence key is not defined + click.echo("predict_input_fn returned no samples.") pass