Reorganized the the TFRecord datashufflers

parent 67c8a148
Pipeline #13007 passed with stages
in 16 minutes and 12 seconds
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST
import numpy
import tensorflow as tf
......@@ -11,24 +10,26 @@ from bob.learn.tensorflow.datashuffler.Normalizer import Linear
class TFRecord(object):
"""
Datashuffler that wraps tfrecords
Generic datashuffler that wraps the batching using tfrecords.
**Parameters**
filename_queue: Tensorflow producer
input_shape: Shape of the input in the tfrecord
input_dtype: Type of the raw data saved in the tf record
batch_size: Size of the batch
seed: Seed
prefetch_capacity: Capacity of the bucket for prefetching
prefetch_threads: Number of threads in the prefetching
"""
def __init__(self,filename_queue,
input_shape=[None, 28, 28, 1],
output_shape=[None, 28, 28, 1],
input_dtype=tf.float32,
batch_size=32,
seed=10,
prefetch_capacity=50,
prefetch_threads=5,
shuffle=True,
normalization=False,
random_flip=True,
random_crop=True
):
prefetch_threads=5):
# Setting the seed for the pseudo random number generator
self.seed = seed
......@@ -42,7 +43,6 @@ class TFRecord(object):
# Preparing the inputs
self.filename_queue = filename_queue
self.input_shape = tuple(input_shape)
self.output_shape = output_shape
# Prefetch variables
self.prefetch = True
......@@ -51,10 +51,6 @@ class TFRecord(object):
self.data_ph = None
self.label_ph = None
self.shuffle = shuffle
self.normalization = normalization
self.random_crop = random_crop
def __call__(self, element, from_queue=False):
"""
......@@ -91,33 +87,17 @@ class TFRecord(object):
# Convert the image data from string back to the numbers
image = tf.decode_raw(features['train/data'], self.input_dtype)
#image = tf.decode_raw(features['train/data'], tf.uint8)
# Cast label data into int32
label = tf.cast(features['train/label'], tf.int64)
# Reshape image data into the original shape
image = tf.reshape(image, self.input_shape[1:])
# Casting to float32
image = tf.cast(image, tf.float32)
if self.random_crop:
image = tf.image.resize_image_with_crop_or_pad(image, self.output_shape[1], self.output_shape[2])
# normalizing data
if self.normalization:
image = tf.image.per_image_standardization(image)
image.set_shape(tuple(self.output_shape[1:]))
if self.shuffle:
data_ph, label_ph = tf.train.shuffle_batch([image, label], batch_size=self.batch_size,
capacity=self.prefetch_capacity, num_threads=self.prefetch_threads,
min_after_dequeue=1, name="shuffle_batch")
else:
data_ph, label_ph = tf.train.batch([image, label], batch_size=self.batch_size,
capacity=self.prefetch_capacity, num_threads=self.prefetch_threads, name="batch")
data_ph, label_ph = tf.train.shuffle_batch([image, label], batch_size=self.batch_size,
capacity=self.prefetch_capacity, num_threads=self.prefetch_threads,
min_after_dequeue=1, name="shuffle_batch")
self.data_ph = data_ph
......@@ -125,17 +105,5 @@ class TFRecord(object):
def get_batch(self):
"""
Shuffle the Memory dataset and get a random batch.
** Returns **
data:
Selected samples
labels:
Correspondent labels
"""
pass
......@@ -10,7 +10,6 @@ import numpy
from bob.learn.tensorflow.datashuffler.Normalizer import Linear
from .TFRecord import TFRecord
class TFRecordImage(TFRecord):
"""
Datashuffler that wraps the batching using tfrecords.
......@@ -148,19 +147,3 @@ class TFRecordImage(TFRecord):
self.data_ph = data_ph
self.label_ph = label_ph
def get_batch(self):
"""
Shuffle the Memory dataset and get a random batch.
** Returns **
data:
Selected samples
labels:
Correspondent labels
"""
pass
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment