Skip to content
Snippets Groups Projects

Updates

Merged Tiago de Freitas Pereira requested to merge updates into master
1 file
+ 16
4
Compare changes
  • Side-by-side
  • Inline
@@ -12,13 +12,19 @@ from bob.learn.tensorflow.datashuffler.Normalizer import Linear
@@ -12,13 +12,19 @@ from bob.learn.tensorflow.datashuffler.Normalizer import Linear
class TFRecord(object):
class TFRecord(object):
 
"""
 
Datashuffler that wraps tfrecords
 
"""
 
 
def __init__(self,filename_queue,
def __init__(self,filename_queue,
input_shape=[None, 28, 28, 1],
input_shape=[None, 28, 28, 1],
input_dtype="float32",
input_dtype="float32",
batch_size=32,
batch_size=32,
seed=10,
seed=10,
prefetch_capacity=50,
prefetch_capacity=50,
prefetch_threads=5):
prefetch_threads=5,
 
shuffle=True):
# Setting the seed for the pseudo random number generator
# Setting the seed for the pseudo random number generator
self.seed = seed
self.seed = seed
@@ -40,6 +46,8 @@ class TFRecord(object):
@@ -40,6 +46,8 @@ class TFRecord(object):
self.data_ph = None
self.data_ph = None
self.label_ph = None
self.label_ph = None
 
 
self.shuffle = shuffle
def __call__(self, element, from_queue=False):
def __call__(self, element, from_queue=False):
"""
"""
@@ -83,10 +91,14 @@ class TFRecord(object):
@@ -83,10 +91,14 @@ class TFRecord(object):
# Reshape image data into the original shape
# Reshape image data into the original shape
image = tf.reshape(image, self.input_shape[1:])
image = tf.reshape(image, self.input_shape[1:])
 
if self.shuffle:
 
data_ph, label_ph = tf.train.shuffle_batch([image, label], batch_size=self.batch_size,
 
capacity=self.prefetch_capacity, num_threads=self.prefetch_threads,
 
min_after_dequeue=1, name="shuffle_batch")
 
else:
 
data_ph, label_ph = tf.train.batch([image, label], batch_size=self.batch_size,
 
capacity=self.prefetch_capacity, num_threads=self.prefetch_threads, name="batch")
data_ph, label_ph = tf.train.shuffle_batch([image, label], batch_size=self.batch_size,
capacity=self.prefetch_capacity, num_threads=self.prefetch_threads,
min_after_dequeue=1, name="shuffle_batch")
self.data_ph = data_ph
self.data_ph = data_ph
Loading