Commit ce620ce6 authored by Pavel KORSHUNOV

tfrecords: added fixed batch shuffling and allow lists of tfrecords

parent a11d5d41
Pipeline #29590 failed in 164 minutes and 41 seconds
@@ -297,7 +297,7 @@ def create_dataset_from_records_with_augmentation(
     if feature is None:
         feature = DEFAULT_FEATURE
-    if os.path.isdir(tfrecord_filenames):
+    if isinstance(tfrecord_filenames, str) and os.path.isdir(tfrecord_filenames):
         tfrecord_filenames = [
             os.path.join(tfrecord_filenames, f) for f in os.listdir(tfrecord_filenames)
         ]
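For callers, the new guard means `tfrecord_filenames` may now be either a directory path or an explicit list of record files; previously a list reached `os.path.isdir`, which raises a `TypeError` for non-string arguments. A minimal sketch of the branch in isolation (`_resolve_tfrecords` is a hypothetical name, not part of this commit):

```python
import os

def _resolve_tfrecords(tfrecord_filenames):
    # A plain string naming a directory is expanded to the files inside it;
    # a list of record paths is passed through untouched instead of being
    # fed to os.path.isdir(), which rejects non-string arguments.
    if isinstance(tfrecord_filenames, str) and os.path.isdir(tfrecord_filenames):
        tfrecord_filenames = [
            os.path.join(tfrecord_filenames, f)
            for f in os.listdir(tfrecord_filenames)
        ]
    return tfrecord_filenames

# Both call styles now work:
_resolve_tfrecords("/path/to/records_dir")
_resolve_tfrecords(["train-0.tfrecords", "train-1.tfrecords"])
```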
@@ -339,6 +339,7 @@ def shuffle_data_and_labels_image_augmentation(
     per_image_normalization=True,
     random_gamma=False,
     random_crop=False,
+    fixed_batch_size=False,
 ):
     """
     Dump random batches from a list of tf-record files and applies some image augmentation
@@ -384,9 +385,12 @@ def shuffle_data_and_labels_image_augmentation(
     random_rotate:
         Randomly rotate face images between -5 and 5 degrees
     per_image_normalization:
         Linearly scales image to have zero mean and unit norm.
+    fixed_batch_size:
+        If True, the last batch is dropped when it is smaller than `batch_size`.
     """
     dataset = create_dataset_from_records_with_augmentation(
@@ -405,7 +409,12 @@ def shuffle_data_and_labels_image_augmentation(
         random_crop=random_crop,
     )
-    dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs)
+    dataset = dataset.shuffle(buffer_size)
+    if fixed_batch_size:
+        dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
+    else:
+        dataset = dataset.batch(batch_size)
+    dataset = dataset.repeat(epochs)
     dataset = dataset.map(lambda d, l, k: ({"data": d, "key": k}, l))
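The motivation for the `fixed_batch_size` branch: `Dataset.batch` alone emits a final short batch whenever the number of records is not a multiple of `batch_size`, which breaks consumers that need a static batch dimension. A toy sketch of the difference, assuming TF 1.x where `tf.contrib` is available:

```python
import tensorflow as tf  # assumes TF 1.x (tf.contrib was removed in TF 2)

dataset = tf.data.Dataset.range(10)

# Plain batching keeps the short final batch: [0 1 2] [3 4 5] [6 7 8] [9]
variable_batches = dataset.batch(3)

# batch_and_drop_remainder emits only full batches: [0 1 2] [3 4 5] [6 7 8],
# so every batch has a static first dimension of exactly 3.
fixed_batches = dataset.apply(tf.contrib.data.batch_and_drop_remainder(3))
```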
@@ -507,6 +516,7 @@ def batch_data_and_labels_image_augmentation(
     per_image_normalization=True,
     random_gamma=False,
     random_crop=False,
+    fixed_batch_size=False,
 ):
     """
     Dump in order batches from a list of tf-record files
@@ -528,6 +538,9 @@ def batch_data_and_labels_image_augmentation(
     epochs:
         Number of epochs to be batched
+    fixed_batch_size:
+        If True, the last batch is dropped when it is smaller than `batch_size`.
     """
     dataset = create_dataset_from_records_with_augmentation(
@@ -546,7 +559,11 @@ def batch_data_and_labels_image_augmentation(
         random_crop=random_crop,
     )
-    dataset = dataset.batch(batch_size).repeat(epochs)
+    if fixed_batch_size:
+        dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
+    else:
+        dataset = dataset.batch(batch_size)
+    dataset = dataset.repeat(epochs)
     data, labels, key = dataset.make_one_shot_iterator().get_next()
     features = dict()
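Worth noting: `tf.contrib.data.batch_and_drop_remainder` was deprecated once `Dataset.batch` gained a `drop_remainder` argument in TF 1.10; on such versions the if/else branch collapses to a single call. A sketch under that assumption (`_batch` is a hypothetical helper, not part of this commit):

```python
import tensorflow as tf  # assumes TF >= 1.10

def _batch(dataset, batch_size, fixed_batch_size, epochs):
    # drop_remainder=True is the in-core replacement for
    # tf.contrib.data.batch_and_drop_remainder(batch_size).
    return dataset.batch(batch_size, drop_remainder=fixed_batch_size).repeat(epochs)
```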