diff --git a/bob/learn/tensorflow/dataset/tfrecords.py b/bob/learn/tensorflow/dataset/tfrecords.py
index 15b25fbaf31033b186b26e44c62e04a5849a5607..9f3cf7d8fb9721943fa409be1a4ad41accd0911e 100644
--- a/bob/learn/tensorflow/dataset/tfrecords.py
+++ b/bob/learn/tensorflow/dataset/tfrecords.py
@@ -181,7 +181,8 @@ def shuffle_data_and_labels_image_augmentation(tfrecord_filenames,
                                                random_contrast=False,
                                                random_saturation=False,
                                                random_rotate=False,
-                                               per_image_normalization=True):
+                                               per_image_normalization=True,
+                                               fixed_batch_size=False):
     """
     Dump random batches from a list of tf-record files and applies some image augmentation
@@ -229,6 +230,9 @@
     per_image_normalization:
         Linearly scales image to have zero mean and unit norm.
 
+    fixed_batch_size:
+        If True, the last remaining batch that is smaller than ``batch_size`` will be dropped.
+
     """
 
     dataset = create_dataset_from_records_with_augmentation(
@@ -244,7 +248,13 @@
         random_rotate=random_rotate,
         per_image_normalization=per_image_normalization)
 
-    dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
+    dataset = dataset.shuffle(buffer_size)
+    if fixed_batch_size:
+        dataset = dataset.apply(
+            tf.contrib.data.batch_and_drop_remainder(batch_size))
+    else:
+        dataset = dataset.batch(batch_size)
+    dataset = dataset.repeat(epochs)
 
     data, labels, key = dataset.make_one_shot_iterator().get_next()
@@ -348,7 +358,8 @@ def batch_data_and_labels_image_augmentation(tfrecord_filenames,
                                              random_contrast=False,
                                              random_saturation=False,
                                              random_rotate=False,
-                                             per_image_normalization=True):
+                                             per_image_normalization=True,
+                                             fixed_batch_size=False):
     """
     Dump in order batches from a list of tf-record files
@@ -369,6 +380,9 @@
     epochs:
         Number of epochs to be batched
 
+    fixed_batch_size:
+        If True, the last remaining batch that is smaller than ``batch_size`` will be dropped.
+ """ dataset = create_dataset_from_records_with_augmentation( @@ -384,7 +398,12 @@ def batch_data_and_labels_image_augmentation(tfrecord_filenames, random_rotate=random_rotate, per_image_normalization=per_image_normalization) - dataset = dataset.batch(batch_size).repeat(epochs) + # dataset = dataset.batch(batch_size).repeat(epochs) + if fixed_batch_size: + dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size)) + else: + dataset = dataset.batch(batch_size) + dataset = dataset.repeat(epochs) data, labels, key = dataset.make_one_shot_iterator().get_next() features = dict() diff --git a/bob/learn/tensorflow/script/predict_bio.py b/bob/learn/tensorflow/script/predict_bio.py index 5a38fccee4d1ab6df8af9a176030d72a7e774acb..65bcf619a86909ffd544135e19d581d0dc80c98b 100644 --- a/bob/learn/tensorflow/script/predict_bio.py +++ b/bob/learn/tensorflow/script/predict_bio.py @@ -243,9 +243,82 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn, load_data=load_data, multiple_samples=multiple_samples) + logger.info("Saving the predictions of %d files in %s", len(generator), + output_dir) + predict_input_fn = bio_predict_input_fn(generator, generator.output_types, generator.output_shapes) + return predict(estimator, checkpoint_path, predict_input_fn, predict_keys, + output_dir, hooks, video_container, **kwargs) + + +@click.command( + entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand) +@click.option( + '--estimator', + '-e', + required=True, + cls=ResourceOption, + entry_point_group='bob.learn.tensorflow.estimator', + help='The estimator that will be evaluated.') +@click.option( + '--predict-input-fn', + required=True, + cls=ResourceOption, + entry_point_group='bob.learn.tensorflow.biogenerator_input', + help='A callable with the signature of ' + '`input_fn = predict_input_fn(generator, output_types, output_shapes)`' + ' The inputs are documented in :any:`tf.data.Dataset.from_generator`' + ' and the output should be a function with no arguments and is passed' + ' to :any:`tf.estimator.Estimator.predict`.') +@click.option( + '--output-dir', + '-o', + required=True, + cls=ResourceOption, + help='The directory to save the predictions.') +@click.option( + '--hooks', + cls=ResourceOption, + multiple=True, + entry_point_group='bob.learn.tensorflow.hook', + help='List of SessionRunHook subclass instances.') +@click.option( + '--predict-keys', + '-k', + multiple=True, + default=None, + cls=ResourceOption, + help='List of `str`, name of the keys to predict. It is used if the ' + '`EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used ' + 'then rest of the predictions will be filtered from the dictionary. ' + 'If `None`, returns all.') +@click.option( + '--checkpoint-path', + '-c', + cls=ResourceOption, + help='Path of a specific checkpoint to predict. If `None`, the ' + 'latest checkpoint in `model_dir` is used. This can also ' + 'be a folder which contains a "checkpoint" file where the ' + 'latest checkpoint from inside this file will be used as ' + 'checkpoint_path.') +@click.option( + '--video-container', + '-vc', + is_flag=True, + cls=ResourceOption, + help='If provided, the predictions will be written in FrameContainers from' + ' bob.bio.video. You need to install bob.bio.video as well.') +@verbosity_option(cls=ResourceOption) +def predict(estimator, checkpoint_path, predict_input_fn, predict_keys, + output_dir, hooks, video_container, **kwargs): + """Saves predictions or embeddings of tf.estimators. 
+
+    This script works with data saved in the tf-record format and requires
+    tensorflow 1.4 or above.
+    """
+
     if checkpoint_path:
         if os.path.isdir(checkpoint_path):
             ckpt = tf.train.get_checkpoint_state(estimator.model_dir)
@@ -261,9 +334,6 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
         checkpoint_path=checkpoint_path,
     )
 
-    logger.info("Saving the predictions of %d files in %s", len(generator),
-                output_dir)
-
     pool = Pool()
     try:
         pred_buffer = defaultdict(list)
@@ -272,8 +342,7 @@
             # key is in bytes format in Python 3
             if sys.version_info >= (3, ):
                 key = key.decode(errors='replace')
-            prob = pred.get('probabilities', pred.get(
-                'embeddings', pred.get('predictions')))
+            prob = pred.get('probabilities', pred.get('embeddings', pred.get('predictions')))
             assert prob is not None
             pred_buffer[key].append(prob)
             if i == 0:
diff --git a/setup.py b/setup.py
index b75e4e679ef05ddffd821547a2d841fc533d52c2..11ad740e06f7b905f0f6ed2234d6208620a82b56 100644
--- a/setup.py
+++ b/setup.py
@@ -56,6 +56,7 @@ setup(
         'describe_tfrecord = bob.learn.tensorflow.script.db_to_tfrecords:describe_tfrecord',
         'eval = bob.learn.tensorflow.script.eval:eval',
         'predict_bio = bob.learn.tensorflow.script.predict_bio:predict_bio',
+        'predict = bob.learn.tensorflow.script.predict_bio:predict',
         'style_transfer = bob.learn.tensorflow.script.style_transfer:style_transfer',
         'train = bob.learn.tensorflow.script.train:train',
         'train_and_evaluate = bob.learn.tensorflow.script.train_and_evaluate:train_and_evaluate',
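
Note (not part of the patch): the new `fixed_batch_size` branch relies on
`tf.contrib.data.batch_and_drop_remainder`, which was the TF 1.x way to
guarantee statically shaped batches before the `drop_remainder=` argument was
added to `Dataset.batch` in TF 1.10. A minimal standalone sketch of the
behaviour, using a toy dataset and assumed values:

    import tensorflow as tf  # TF 1.x, with tf.contrib available

    dataset = tf.data.Dataset.range(10)
    batch_size = 4
    # With a fixed batch size, the trailing batch of 2 elements is dropped,
    # so every batch has a static first dimension of exactly `batch_size`.
    dataset = dataset.apply(
        tf.contrib.data.batch_and_drop_remainder(batch_size))

    batch = dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
        print(sess.run(batch))  # [0 1 2 3]
        print(sess.run(batch))  # [4 5 6 7]
        # A third run raises tf.errors.OutOfRangeError: the remainder
        # [8 9] was dropped.

This matters for graphs that hard-code the batch dimension (for example,
reshapes that use the batch size as a static value).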
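The new `predict` command expects `--predict-input-fn` to resolve to a
callable with the signature documented in its help text. A hypothetical
sketch of such a callable follows; the feature names `data` and `key` and the
assumption that the generator yields `(data, key)` pairs are illustrations
only, to be adapted to what your estimator's `model_fn` and the prediction
loop actually consume:

    import tensorflow as tf

    def predict_input_fn(generator, output_types, output_shapes):
        # Returns a no-argument input_fn suitable for
        # tf.estimator.Estimator.predict.
        def input_fn():
            dataset = tf.data.Dataset.from_generator(
                generator, output_types, output_shapes)
            dataset = dataset.batch(1)
            data, key = dataset.make_one_shot_iterator().get_next()
            # Hypothetical feature layout; a per-sample key must reach the
            # predictions so the loop above can group results by key.
            return {'data': data, 'key': key}
        return input_fn

Such a callable would be registered under the
`bob.learn.tensorflow.biogenerator_input` entry point group so that
`--predict-input-fn` can resolve it by name.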
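With the setup.py entry added, the command becomes available from the command
line. Assuming the package exposes these scripts through the `bob tf` command
group, as the neighbouring `eval`, `train`, and `predict_bio` entries suggest,
a hypothetical invocation would look like:

    bob tf predict -e estimator_config.py --predict-input-fn my_input_fn \
        -o /path/to/predictions -vv

where `estimator_config.py` and `my_input_fn` are placeholder resource names,
not files shipped by this patch.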