Skip to content
Snippets Groups Projects

Updates

Merged Tiago de Freitas Pereira requested to merge updates into master
1 file
+ 23
5
Compare changes
  • Side-by-side
  • Inline
@@ -16,15 +16,19 @@ class TFRecord(object):
Datashuffler that wraps tfrecords
"""
def __init__(self,filename_queue,
input_shape=[None, 28, 28, 1],
input_dtype="float32",
output_shape=[None, 28, 28, 1],
input_dtype=tf.float32,
batch_size=32,
seed=10,
prefetch_capacity=50,
prefetch_threads=5,
shuffle=True):
shuffle=True,
normalization=False,
random_flip=True,
random_crop=True
):
# Setting the seed for the pseudo random number generator
self.seed = seed
@@ -38,6 +42,7 @@ class TFRecord(object):
# Preparing the inputs
self.filename_queue = filename_queue
self.input_shape = tuple(input_shape)
self.output_shape = output_shape
# Prefetch variables
self.prefetch = True
@@ -48,6 +53,8 @@ class TFRecord(object):
self.label_ph = None
self.shuffle = shuffle
self.normalization = normalization
self.random_crop = random_crop
def __call__(self, element, from_queue=False):
"""
@@ -83,7 +90,7 @@ class TFRecord(object):
features = tf.parse_single_example(serialized_example, features=feature)
# Convert the image data from string back to the numbers
image = tf.decode_raw(features['train/data'], tf.float32)
image = tf.decode_raw(features['train/data'], self.input_dtype)
#image = tf.decode_raw(features['train/data'], tf.uint8)
# Cast label data into int32
@@ -91,7 +98,19 @@ class TFRecord(object):
# Reshape image data into the original shape
image = tf.reshape(image, self.input_shape[1:])
# Casting to float32
image = tf.cast(image, tf.float32)
if self.random_crop:
image = tf.image.resize_image_with_crop_or_pad(image, self.output_shape[1], self.output_shape[2])
# normalizing data
if self.normalization:
image = tf.image.per_image_standardization(image)
image.set_shape(tuple(self.output_shape[1:]))
if self.shuffle:
data_ph, label_ph = tf.train.shuffle_batch([image, label], batch_size=self.batch_size,
capacity=self.prefetch_capacity, num_threads=self.prefetch_threads,
@@ -101,7 +120,6 @@ class TFRecord(object):
capacity=self.prefetch_capacity, num_threads=self.prefetch_threads, name="batch")
self.data_ph = data_ph
self.label_ph = label_ph
Loading