Skip to content
Snippets Groups Projects
Commit 0c76788e authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Created a function that batches a tf-record in order and apply data augmentation

parent b9c51932
No related branches found
No related tags found
1 merge request!29Created a function that batches a tf-record in order and apply data augmentation
......@@ -291,4 +291,55 @@ def batch_data_and_labels(tfrecord_filenames, data_shape, data_type,
features['key'] = key
return features, labels
def batch_data_and_labels_image_augmentation(tfrecord_filenames, data_shape, data_type,
batch_size, epochs=1,
gray_scale=False,
output_shape=None,
random_flip=False,
random_brightness=False,
random_contrast=False,
random_saturation=False,
per_image_normalization=True):
"""
Dump in order batches from a list of tf-record files
**Parameters**
tfrecord_filenames:
List containing the tf-record paths
data_shape:
Samples shape saved in the tf-record
data_type:
tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
batch_size:
Size of the batch
epochs:
Number of epochs to be batched
"""
dataset = create_dataset_from_records_with_augmentation(tfrecord_filenames, data_shape,
data_type,
gray_scale=gray_scale,
output_shape=output_shape,
random_flip=random_flip,
random_brightness=random_brightness,
random_contrast=random_contrast,
random_saturation=random_saturation,
per_image_normalization=per_image_normalization)
dataset = dataset.batch(batch_size).repeat(epochs)
data, labels, key = dataset.make_one_shot_iterator().get_next()
features = dict()
features['data'] = data
features['key'] = key
return features, labels
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment