From 0c76788ed637741a720b9ad313074a1371cfbe1f Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Tue, 31 Oct 2017 09:30:43 +0100 Subject: [PATCH] Created a function that batches a tf-record in order and apply data augmentation --- bob/learn/tensorflow/dataset/tfrecords.py | 51 +++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/bob/learn/tensorflow/dataset/tfrecords.py b/bob/learn/tensorflow/dataset/tfrecords.py index 84c21ab9..faa0938a 100644 --- a/bob/learn/tensorflow/dataset/tfrecords.py +++ b/bob/learn/tensorflow/dataset/tfrecords.py @@ -291,4 +291,55 @@ def batch_data_and_labels(tfrecord_filenames, data_shape, data_type, features['key'] = key return features, labels + + +def batch_data_and_labels_image_augmentation(tfrecord_filenames, data_shape, data_type, + batch_size, epochs=1, + gray_scale=False, + output_shape=None, + random_flip=False, + random_brightness=False, + random_contrast=False, + random_saturation=False, + per_image_normalization=True): + """ + Dump in order batches from a list of tf-record files + + **Parameters** + + tfrecord_filenames: + List containing the tf-record paths + + data_shape: + Samples shape saved in the tf-record + + data_type: + tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types) + + batch_size: + Size of the batch + + epochs: + Number of epochs to be batched + + """ + + dataset = create_dataset_from_records_with_augmentation(tfrecord_filenames, data_shape, + data_type, + gray_scale=gray_scale, + output_shape=output_shape, + random_flip=random_flip, + random_brightness=random_brightness, + random_contrast=random_contrast, + random_saturation=random_saturation, + per_image_normalization=per_image_normalization) + + dataset = dataset.batch(batch_size).repeat(epochs) + + data, labels, key = dataset.make_one_shot_iterator().get_next() + features = dict() + features['data'] = data + features['key'] = key + + return features, labels -- GitLab