Skip to content
Snippets Groups Projects
Commit c3fa06d7 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Add keyword argument to flag the shuffling mechanism

parent 171ede05
No related branches found
No related tags found
1 merge request!17Updates
......@@ -12,13 +12,19 @@ from bob.learn.tensorflow.datashuffler.Normalizer import Linear
class TFRecord(object):
"""
Datashuffler that wraps tfrecords
"""
def __init__(self,filename_queue,
input_shape=[None, 28, 28, 1],
input_dtype="float32",
batch_size=32,
seed=10,
prefetch_capacity=50,
prefetch_threads=5):
prefetch_threads=5,
shuffle=True):
# Setting the seed for the pseudo random number generator
self.seed = seed
......@@ -40,6 +46,8 @@ class TFRecord(object):
self.data_ph = None
self.label_ph = None
self.shuffle = shuffle
def __call__(self, element, from_queue=False):
"""
......@@ -83,10 +91,14 @@ class TFRecord(object):
# Reshape image data into the original shape
image = tf.reshape(image, self.input_shape[1:])
if self.shuffle:
data_ph, label_ph = tf.train.shuffle_batch([image, label], batch_size=self.batch_size,
capacity=self.prefetch_capacity, num_threads=self.prefetch_threads,
min_after_dequeue=1, name="shuffle_batch")
else:
data_ph, label_ph = tf.train.batch([image, label], batch_size=self.batch_size,
capacity=self.prefetch_capacity, num_threads=self.prefetch_threads, name="batch")
data_ph, label_ph = tf.train.shuffle_batch([image, label], batch_size=self.batch_size,
capacity=self.prefetch_capacity, num_threads=self.prefetch_threads,
min_after_dequeue=1, name="shuffle_batch")
self.data_ph = data_ph
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment