def batch_data_and_labels_image_augmentation(tfrecord_filenames, data_shape, data_type,
                                             batch_size, epochs=1,
                                             gray_scale=False,
                                             output_shape=None,
                                             random_flip=False,
                                             random_brightness=False,
                                             random_contrast=False,
                                             random_saturation=False,
                                             per_image_normalization=True):
    """
    Dump in-order batches from a list of tf-record files, with optional
    per-sample image augmentation.

    Returns an ``(features, labels)`` pair suitable for a ``tf.estimator``
    input_fn, where ``features`` is a dict with keys ``'data'`` and ``'key'``.

    **Parameters**

    tfrecord_filenames:
       List containing the tf-record paths

    data_shape:
       Samples shape saved in the tf-record

    data_type:
       tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

    batch_size:
       Size of the batch

    epochs:
       Number of epochs to be batched

    gray_scale:
       Forwarded to :any:`create_dataset_from_records_with_augmentation`
       (presumably converts each sample to gray scale — confirm there)

    output_shape:
       Forwarded to :any:`create_dataset_from_records_with_augmentation`
       (presumably the target sample shape after augmentation — confirm there)

    random_flip:
       Forwarded to :any:`create_dataset_from_records_with_augmentation`

    random_brightness:
       Forwarded to :any:`create_dataset_from_records_with_augmentation`

    random_contrast:
       Forwarded to :any:`create_dataset_from_records_with_augmentation`

    random_saturation:
       Forwarded to :any:`create_dataset_from_records_with_augmentation`

    per_image_normalization:
       Forwarded to :any:`create_dataset_from_records_with_augmentation`

    """

    dataset = create_dataset_from_records_with_augmentation(
        tfrecord_filenames, data_shape, data_type,
        gray_scale=gray_scale,
        output_shape=output_shape,
        random_flip=random_flip,
        random_brightness=random_brightness,
        random_contrast=random_contrast,
        random_saturation=random_saturation,
        per_image_normalization=per_image_normalization)

    # Batch first, then repeat: records stay in file order within each epoch.
    dataset = dataset.batch(batch_size).repeat(epochs)

    # TF1-style iterator; the dataset yields (data, labels, key) tensors.
    data, labels, key = dataset.make_one_shot_iterator().get_next()
    features = {'data': data, 'key': key}

    return features, labels