Commit 3959fa6f authored by Amir MOHAMMADI

rename fixed_batch_size to drop_remainder

parent ce620ce6
Merge request !75: A lot of new features
Pipeline #29595 failed
@@ -339,7 +339,7 @@ def shuffle_data_and_labels_image_augmentation(
     per_image_normalization=True,
     random_gamma=False,
     random_crop=False,
-    fixed_batch_size=False,
+    drop_remainder=False,
 ):
     """
     Dump random batches from a list of tf-record files and applies some image augmentation
@@ -388,7 +388,7 @@ def shuffle_data_and_labels_image_augmentation(
     per_image_normalization:
        Linearly scales image to have zero mean and unit norm.
 
-    fixed_batch_size:
+    drop_remainder:
        If True, the last remaining batch that has smaller size than `batch_size' will be dropped.
 
     """
@@ -410,10 +410,7 @@ def shuffle_data_and_labels_image_augmentation(
     )
     dataset = dataset.shuffle(buffer_size)
-    if fixed_batch_size:
-        dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
-    else:
-        dataset = dataset.batch(batch_size)
+    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
     dataset = dataset.repeat(epochs)
     dataset = dataset.map(lambda d, l, k: ({"data": d, "key": k}, l))
@@ -516,7 +513,7 @@ def batch_data_and_labels_image_augmentation(
     per_image_normalization=True,
     random_gamma=False,
     random_crop=False,
-    fixed_batch_size=False,
+    drop_remainder=False,
 ):
     """
     Dump in order batches from a list of tf-record files
@@ -538,7 +535,7 @@ def batch_data_and_labels_image_augmentation(
     epochs:
        Number of epochs to be batched
 
-    fixed_batch_size:
+    drop_remainder:
        If True, the last remaining batch that has smaller size than `batch_size' will be dropped.
 
     """
@@ -559,10 +556,7 @@ def batch_data_and_labels_image_augmentation(
         random_crop=random_crop,
     )
-    if fixed_batch_size:
-        dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
-    else:
-        dataset = dataset.batch(batch_size)
+    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
     dataset = dataset.repeat(epochs)
     data, labels, key = dataset.make_one_shot_iterator().get_next()
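
For context, the sketch below (not part of the commit) illustrates the batching behaviour both functions now rely on, assuming a TensorFlow version in which tf.data.Dataset.batch() accepts a drop_remainder argument (roughly 1.10 and later), which supersedes the tf.contrib.data.batch_and_drop_remainder() transformation. The dataset contents and batch size are made up for illustration.

# Minimal sketch (not from this repository): how drop_remainder changes batching.
# Assumes tf.data.Dataset.batch() accepts the drop_remainder argument.
import tensorflow as tf

dataset = tf.data.Dataset.range(10)  # ten toy samples, batched in groups of 3 below

# Default behaviour: the final, smaller batch (here of size 1) is kept.
keep_last = dataset.batch(3, drop_remainder=False)

# Behaviour requested by drop_remainder=True (previously fixed_batch_size=True,
# implemented via tf.contrib.data.batch_and_drop_remainder): the partial batch is
# discarded, so every emitted batch has exactly batch_size elements.
drop_last = dataset.batch(3, drop_remainder=True)

# With eager execution enabled (e.g. TensorFlow 2.x):
#   [b.numpy().tolist() for b in keep_last] -> [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
#   [b.numpy().tolist() for b in drop_last] -> [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

Dropping the remainder gives every batch a static first dimension, which some graph components require; with the default setting no data is discarded but the last batch may be smaller.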