Skip to content
Snippets Groups Projects

Updates

Merged Tiago de Freitas Pereira requested to merge updates into master
2 files
+ 16
- 65
Compare changes
  • Side-by-side
  • Inline
Files
2
#!/usr/bin/env python
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST
import numpy
import numpy
import tensorflow as tf
import tensorflow as tf
@@ -11,24 +10,26 @@ from bob.learn.tensorflow.datashuffler.Normalizer import Linear
@@ -11,24 +10,26 @@ from bob.learn.tensorflow.datashuffler.Normalizer import Linear
class TFRecord(object):
class TFRecord(object):
"""
"""
Datashuffler that wraps tfrecords
Generic datashuffler that wraps the batching using tfrecords.
 
 
**Parameters**
 
 
filename_queue: Tensorflow producer
 
input_shape: Shape of the input in the tfrecord
 
input_dtype: Type of the raw data saved in the tf record
 
batch_size: Size of the batch
 
seed: Seed for the pseudo random number generator
 
prefetch_capacity: Capacity of the bucket for prefetching
 
prefetch_threads: Number of threads in the prefetching
"""
"""
def __init__(self,filename_queue,
def __init__(self,filename_queue,
input_shape=[None, 28, 28, 1],
input_shape=[None, 28, 28, 1],
output_shape=[None, 28, 28, 1],
input_dtype=tf.float32,
input_dtype=tf.float32,
batch_size=32,
batch_size=32,
seed=10,
seed=10,
prefetch_capacity=50,
prefetch_capacity=50,
prefetch_threads=5,
prefetch_threads=5):
shuffle=True,
normalization=False,
random_flip=True,
random_crop=True
):
# Setting the seed for the pseudo random number generator
# Setting the seed for the pseudo random number generator
self.seed = seed
self.seed = seed
@@ -42,7 +43,6 @@ class TFRecord(object):
@@ -42,7 +43,6 @@ class TFRecord(object):
# Preparing the inputs
# Preparing the inputs
self.filename_queue = filename_queue
self.filename_queue = filename_queue
self.input_shape = tuple(input_shape)
self.input_shape = tuple(input_shape)
self.output_shape = output_shape
# Prefetch variables
# Prefetch variables
self.prefetch = True
self.prefetch = True
@@ -51,10 +51,6 @@ class TFRecord(object):
@@ -51,10 +51,6 @@ class TFRecord(object):
self.data_ph = None
self.data_ph = None
self.label_ph = None
self.label_ph = None
self.shuffle = shuffle
self.normalization = normalization
self.random_crop = random_crop
def __call__(self, element, from_queue=False):
def __call__(self, element, from_queue=False):
"""
"""
@@ -91,33 +87,17 @@ class TFRecord(object):
@@ -91,33 +87,17 @@ class TFRecord(object):
# Convert the image data from string back to the numbers
# Convert the image data from string back to the numbers
image = tf.decode_raw(features['train/data'], self.input_dtype)
image = tf.decode_raw(features['train/data'], self.input_dtype)
#image = tf.decode_raw(features['train/data'], tf.uint8)
# Cast label data into int64
# Cast label data into int64
label = tf.cast(features['train/label'], tf.int64)
label = tf.cast(features['train/label'], tf.int64)
# Reshape image data into the original shape
# Reshape image data into the original shape
image = tf.reshape(image, self.input_shape[1:])
image = tf.reshape(image, self.input_shape[1:])
# Casting to float32
image = tf.cast(image, tf.float32)
if self.random_crop:
image = tf.image.resize_image_with_crop_or_pad(image, self.output_shape[1], self.output_shape[2])
data_ph, label_ph = tf.train.shuffle_batch([image, label], batch_size=self.batch_size,
capacity=self.prefetch_capacity, num_threads=self.prefetch_threads,
# normalizing data
min_after_dequeue=1, name="shuffle_batch")
if self.normalization:
image = tf.image.per_image_standardization(image)
image.set_shape(tuple(self.output_shape[1:]))
if self.shuffle:
data_ph, label_ph = tf.train.shuffle_batch([image, label], batch_size=self.batch_size,
capacity=self.prefetch_capacity, num_threads=self.prefetch_threads,
min_after_dequeue=1, name="shuffle_batch")
else:
data_ph, label_ph = tf.train.batch([image, label], batch_size=self.batch_size,
capacity=self.prefetch_capacity, num_threads=self.prefetch_threads, name="batch")
self.data_ph = data_ph
self.data_ph = data_ph
@@ -125,17 +105,5 @@ class TFRecord(object):
@@ -125,17 +105,5 @@ class TFRecord(object):
def get_batch(self):
def get_batch(self):
"""
Shuffle the Memory dataset and get a random batch.
**Returns**
data:
Selected samples
labels:
Corresponding labels
"""
pass
pass
Loading