Skip to content
Snippets Groups Projects

Created a function that batches a tf-record in order and apply data augmentation

Merged Tiago de Freitas Pereira requested to merge batch-tfrecord into master
1 file
+ 51
0
Compare changes
  • Side-by-side
  • Inline
@@ -291,4 +291,55 @@ def batch_data_and_labels(tfrecord_filenames, data_shape, data_type,
features['key'] = key
return features, labels
def batch_data_and_labels_image_augmentation(tfrecord_filenames, data_shape, data_type,
batch_size, epochs=1,
gray_scale=False,
output_shape=None,
random_flip=False,
random_brightness=False,
random_contrast=False,
random_saturation=False,
per_image_normalization=True):
"""
Dump in order batches from a list of tf-record files
**Parameters**
tfrecord_filenames:
List containing the tf-record paths
data_shape:
Samples shape saved in the tf-record
data_type:
tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
batch_size:
Size of the batch
epochs:
Number of epochs to be batched
"""
dataset = create_dataset_from_records_with_augmentation(tfrecord_filenames, data_shape,
data_type,
gray_scale=gray_scale,
output_shape=output_shape,
random_flip=random_flip,
random_brightness=random_brightness,
random_contrast=random_contrast,
random_saturation=random_saturation,
per_image_normalization=per_image_normalization)
dataset = dataset.batch(batch_size).repeat(epochs)
data, labels, key = dataset.make_one_shot_iterator().get_next()
features = dict()
features['data'] = data
features['key'] = key
return features, labels
Loading