Skip to content
Snippets Groups Projects

Removed code duplication between TFRecord and TFRecordImage.

Merged Samuel GAIST requested to merge cleanup_tfrecord_code_duplication into master
2 files
+ 52
87
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -11,7 +11,7 @@ import numpy
class TFRecord(object):
"""
Generic datashuffler that wraps the batching using tfrecords.
**Parameters**
filename_queue: Tensorflow producer
@@ -19,7 +19,7 @@ class TFRecord(object):
input_dtype: Type of the raw data saved in the tf record
batch_size: Size of the batch
seed: Seed
prefetch_capacity: Capacity of the bucket for prefetching
prefetch_capacity: Capacity of the bucket for prefetching
prefetch_threads: Number of threads in the prefetching
"""
def __init__(self,filename_queue,
@@ -36,7 +36,7 @@ class TFRecord(object):
self.input_dtype = input_dtype
# TODO: Check if the bacth size is higher than the input data
# TODO: Check if the batch size is higher than the input data
self.batch_size = batch_size
# Preparing the inputs
@@ -47,14 +47,14 @@ class TFRecord(object):
self.prefetch = True
self.prefetch_capacity = prefetch_capacity
self.prefetch_threads = prefetch_threads
self.data_ph = None
self.label_ph = None
def __call__(self, element, from_queue=False):
"""
Return the necessary placeholder
"""
if not element in ["data", "label"]:
@@ -70,35 +70,47 @@ class TFRecord(object):
return self.label_ph
def create_placeholders(self):
def __load_features(self):
    """
    Read the next record from ``self.filename_queue`` and decode it.

    Each record is parsed with two scalar features: a string under
    ``'train/data'`` and an int64 under ``'train/label'``.

    Returns a ``(image, label)`` pair of tensors: the data decoded as
    ``self.input_dtype`` and reshaped to ``self.input_shape[1:]``, and the
    label as int64.
    """
    # Layout of a single serialized example
    feature_spec = {
        'train/data': tf.FixedLenFeature([], tf.string),
        'train/label': tf.FixedLenFeature([], tf.int64),
    }

    # Fetch the next serialized record; the record key is not used
    record_reader = tf.TFRecordReader()
    _record_key, serialized = record_reader.read(self.filename_queue)

    parsed = tf.parse_single_example(serialized, features=feature_spec)

    # Raw bytes back to numbers of the configured dtype
    raw_data = tf.decode_raw(parsed['train/data'], self.input_dtype)

    # Label is kept as int64
    label = tf.cast(parsed['train/label'], tf.int64)

    # Restore the per-sample shape: input_shape without its first entry
    # (presumably the batch dimension -- confirm against the constructor)
    image = tf.reshape(raw_data, self.input_shape[1:])

    return image, label
def create_placeholders(self):
    """
    Build the shuffled batch tensors and keep them on the instance.

    A single (data, label) pair is loaded through ``__load_features`` and
    passed to ``tf.train.shuffle_batch``; the resulting batch tensors are
    stored in ``self.data_ph`` and ``self.label_ph``. Nothing is returned.
    """
    single_image, single_label = self.__load_features()

    # Batch size, queue capacity and thread count come from the instance
    # settings; min_after_dequeue and the op name are fixed.
    self.data_ph, self.label_ph = tf.train.shuffle_batch(
        [single_image, single_label],
        name="shuffle_batch",
        min_after_dequeue=1,
        num_threads=self.prefetch_threads,
        capacity=self.prefetch_capacity,
        batch_size=self.batch_size,
    )
Loading