diff --git a/bob/learn/tensorflow/dataset/tfrecords.py b/bob/learn/tensorflow/dataset/tfrecords.py
index 15b25fbaf31033b186b26e44c62e04a5849a5607..9f3cf7d8fb9721943fa409be1a4ad41accd0911e 100644
--- a/bob/learn/tensorflow/dataset/tfrecords.py
+++ b/bob/learn/tensorflow/dataset/tfrecords.py
@@ -181,7 +181,8 @@ def shuffle_data_and_labels_image_augmentation(tfrecord_filenames,
                                                random_contrast=False,
                                                random_saturation=False,
                                                random_rotate=False,
-                                               per_image_normalization=True):
+                                               per_image_normalization=True,
+                                               fixed_batch_size=False):
     """
     Dump random batches from a list of tf-record files and applies some image augmentation
 
@@ -229,6 +230,9 @@ def shuffle_data_and_labels_image_augmentation(tfrecord_filenames,
       per_image_normalization:
            Linearly scales image to have zero mean and unit norm.
 
+      fixed_batch_size:
+           If True, the last remaining batch that has smaller size than `batch_size' will be dropped.
+
     """
 
     dataset = create_dataset_from_records_with_augmentation(
@@ -244,7 +248,13 @@ def shuffle_data_and_labels_image_augmentation(tfrecord_filenames,
         random_rotate=random_rotate,
         per_image_normalization=per_image_normalization)
 
-    dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
+    # dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
+    dataset = dataset.shuffle(buffer_size)
+    if fixed_batch_size:
+        dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
+    else:
+        dataset = dataset.batch(batch_size)
+    dataset = dataset.repeat(epochs)
 
     data, labels, key = dataset.make_one_shot_iterator().get_next()
 
@@ -348,7 +358,8 @@ def batch_data_and_labels_image_augmentation(tfrecord_filenames,
                                              random_contrast=False,
                                              random_saturation=False,
                                              random_rotate=False,
-                                             per_image_normalization=True):
+                                             per_image_normalization=True,
+                                             fixed_batch_size=False):
     """
     Dump in order batches from a list of tf-record files
 
@@ -369,6 +380,9 @@ def batch_data_and_labels_image_augmentation(tfrecord_filenames,
        epochs:
            Number of epochs to be batched
 
+      fixed_batch_size:
+           If True, the last remaining batch that has smaller size than `batch_size' will be dropped.
+
     """
 
     dataset = create_dataset_from_records_with_augmentation(
@@ -384,7 +398,12 @@ def batch_data_and_labels_image_augmentation(tfrecord_filenames,
         random_rotate=random_rotate,
         per_image_normalization=per_image_normalization)
 
-    dataset = dataset.batch(batch_size).repeat(epochs)
+    # dataset = dataset.batch(batch_size).repeat(epochs)
+    if fixed_batch_size:
+        dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
+    else:
+        dataset = dataset.batch(batch_size)
+    dataset = dataset.repeat(epochs)
 
     data, labels, key = dataset.make_one_shot_iterator().get_next()
     features = dict()
diff --git a/bob/learn/tensorflow/script/predict_bio.py b/bob/learn/tensorflow/script/predict_bio.py
index 5a38fccee4d1ab6df8af9a176030d72a7e774acb..65bcf619a86909ffd544135e19d581d0dc80c98b 100644
--- a/bob/learn/tensorflow/script/predict_bio.py
+++ b/bob/learn/tensorflow/script/predict_bio.py
@@ -243,9 +243,82 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
         load_data=load_data,
         multiple_samples=multiple_samples)
 
+    logger.info("Saving the predictions of %d files in %s", len(generator),
+                output_dir)
+
     predict_input_fn = bio_predict_input_fn(generator, generator.output_types,
                                             generator.output_shapes)
 
+    return predict(estimator, checkpoint_path, predict_input_fn, predict_keys,
+                   output_dir, hooks, video_container, **kwargs)
+
+
+@click.command(
+    entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
+@click.option(
+    '--estimator',
+    '-e',
+    required=True,
+    cls=ResourceOption,
+    entry_point_group='bob.learn.tensorflow.estimator',
+    help='The estimator that will be evaluated.')
+@click.option(
+    '--predict-input-fn',
+    required=True,
+    cls=ResourceOption,
+    entry_point_group='bob.learn.tensorflow.biogenerator_input',
+    help='A callable with the signature of '
+         '`input_fn = predict_input_fn(generator, output_types, output_shapes)`'
+         ' The inputs are documented in :any:`tf.data.Dataset.from_generator`'
+         ' and the output should be a function with no arguments and is passed'
+         ' to :any:`tf.estimator.Estimator.predict`.')
+@click.option(
+    '--output-dir',
+    '-o',
+    required=True,
+    cls=ResourceOption,
+    help='The directory to save the predictions.')
+@click.option(
+    '--hooks',
+    cls=ResourceOption,
+    multiple=True,
+    entry_point_group='bob.learn.tensorflow.hook',
+    help='List of SessionRunHook subclass instances.')
+@click.option(
+    '--predict-keys',
+    '-k',
+    multiple=True,
+    default=None,
+    cls=ResourceOption,
+    help='List of `str`, name of the keys to predict. It is used if the '
+         '`EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used '
+         'then rest of the predictions will be filtered from the dictionary. '
+         'If `None`, returns all.')
+@click.option(
+    '--checkpoint-path',
+    '-c',
+    cls=ResourceOption,
+    help='Path of a specific checkpoint to predict. If `None`, the '
+         'latest checkpoint in `model_dir` is used. This can also '
+         'be a folder which contains a "checkpoint" file where the '
+         'latest checkpoint from inside this file will be used as '
+         'checkpoint_path.')
+@click.option(
+    '--video-container',
+    '-vc',
+    is_flag=True,
+    cls=ResourceOption,
+    help='If provided, the predictions will be written in FrameContainers from'
+    ' bob.bio.video. You need to install bob.bio.video as well.')
+@verbosity_option(cls=ResourceOption)
+def predict(estimator, checkpoint_path, predict_input_fn, predict_keys,
+            output_dir, hooks, video_container, **kwargs):
+    """Saves predictions or embeddings of tf.estimators.
+
+    This script works with data saved in tf-record format. This script works with
+    tensorflow 1.4 and above.
+    """
+
     if checkpoint_path:
         if os.path.isdir(checkpoint_path):
             ckpt = tf.train.get_checkpoint_state(estimator.model_dir)
@@ -261,9 +334,6 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
         checkpoint_path=checkpoint_path,
     )
 
-    logger.info("Saving the predictions of %d files in %s", len(generator),
-                output_dir)
-
     pool = Pool()
     try:
         pred_buffer = defaultdict(list)
@@ -272,8 +342,7 @@ def predict_bio(estimator, database, biofiles, bio_predict_input_fn,
             # key is in bytes format in Python 3
             if sys.version_info >= (3, ):
                 key = key.decode(errors='replace')
-            prob = pred.get('probabilities', pred.get(
-                'embeddings', pred.get('predictions')))
+            prob = pred.get('probabilities', pred.get('embeddings', pred.get('predictions')))
             assert prob is not None
             pred_buffer[key].append(prob)
             if i == 0:
diff --git a/setup.py b/setup.py
index b75e4e679ef05ddffd821547a2d841fc533d52c2..11ad740e06f7b905f0f6ed2234d6208620a82b56 100644
--- a/setup.py
+++ b/setup.py
@@ -56,6 +56,7 @@ setup(
             'describe_tfrecord = bob.learn.tensorflow.script.db_to_tfrecords:describe_tfrecord',
             'eval = bob.learn.tensorflow.script.eval:eval',
             'predict_bio = bob.learn.tensorflow.script.predict_bio:predict_bio',
+            'predict = bob.learn.tensorflow.script.predict_bio:predict',
             'style_transfer = bob.learn.tensorflow.script.style_transfer:style_transfer',
             'train = bob.learn.tensorflow.script.train:train',
             'train_and_evaluate = bob.learn.tensorflow.script.train_and_evaluate:train_and_evaluate',