Skip to content
Snippets Groups Projects
Commit ccac6262 authored by Olivier Canévet's avatar Olivier Canévet
Browse files

[datashuffler] Label can be float and of arbitrary shape

parent 74531ca4
Branches
No related tags found
No related merge requests found
#!/usr/bin/env python #!/usr/bin/env python
# vim: set fileencoding=utf-8 : # vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch> # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST # @date: Wed 11 May 2016 09:39:36 CEST
import numpy import numpy
import tensorflow as tf import tensorflow as tf
...@@ -39,18 +39,20 @@ class Base(object): ...@@ -39,18 +39,20 @@ class Base(object):
normalizer: normalizer:
The algorithm used for feature scaling. Look :py:class:`bob.learn.tensorflow.datashuffler.ScaleFactor`, :py:class:`bob.learn.tensorflow.datashuffler.Linear` and :py:class:`bob.learn.tensorflow.datashuffler.MeanOffset` The algorithm used for feature scaling. Look :py:class:`bob.learn.tensorflow.datashuffler.ScaleFactor`, :py:class:`bob.learn.tensorflow.datashuffler.Linear` and :py:class:`bob.learn.tensorflow.datashuffler.MeanOffset`
prefetch: prefetch:
Do prefetch? Do prefetch?
prefetch_capacity: prefetch_capacity:
""" """
def __init__(self, data, labels, def __init__(self, data, labels,
input_shape=[None, 28, 28, 1], input_shape=[None, 28, 28, 1],
input_dtype="float32", input_dtype="float32",
input_lshape=[None],
input_ltype=tf.int64,
batch_size=32, batch_size=32,
seed=10, seed=10,
data_augmentation=None, data_augmentation=None,
...@@ -65,6 +67,7 @@ class Base(object): ...@@ -65,6 +67,7 @@ class Base(object):
self.normalizer = normalizer self.normalizer = normalizer
self.input_dtype = input_dtype self.input_dtype = input_dtype
self.input_ltype = input_ltype
        # TODO: Check if the batch size is higher than the input data         # TODO: Check if the batch size is higher than the input data
self.batch_size = batch_size self.batch_size = batch_size
...@@ -74,6 +77,8 @@ class Base(object): ...@@ -74,6 +77,8 @@ class Base(object):
self.input_shape = tuple(input_shape) self.input_shape = tuple(input_shape)
self.labels = labels self.labels = labels
self.possible_labels = list(set(self.labels)) self.possible_labels = list(set(self.labels))
self.input_lshape = tuple(input_lshape)
        # Computing the data samples for train and validation         # Computing the data samples for train and validation
self.n_samples = len(self.labels) self.n_samples = len(self.labels)
...@@ -101,13 +106,17 @@ class Base(object): ...@@ -101,13 +106,17 @@ class Base(object):
def create_placeholders(self): def create_placeholders(self):
""" """
Create place holder instances Create place holder instances
:return: :return:
""" """
with tf.name_scope("Input"): with tf.name_scope("Input"):
self.data_ph = tf.placeholder(tf.float32, shape=self.input_shape, name="data") self.data_ph = tf.placeholder(self.input_dtype,
self.label_ph = tf.placeholder(tf.int64, shape=[None], name="label") shape=self.input_shape,
name="data")
self.label_ph = tf.placeholder(self.input_ltype,
shape=self.input_lshape,
name="label")
# If prefetch, setup the queue to feed data # If prefetch, setup the queue to feed data
if self.prefetch: if self.prefetch:
...@@ -126,7 +135,7 @@ class Base(object): ...@@ -126,7 +135,7 @@ class Base(object):
def __call__(self, element, from_queue=False): def __call__(self, element, from_queue=False):
""" """
Return the necessary placeholder Return the necessary placeholder
""" """
if not element in ["data", "label"]: if not element in ["data", "label"]:
...@@ -264,7 +273,7 @@ class Base(object): ...@@ -264,7 +273,7 @@ class Base(object):
try: try:
for i in range(self.batch_size): for i in range(self.batch_size):
data = self.batch_generator.next() data = self.batch_generator.next()
holder.append(data) holder.append(data)
if len(holder) == self.batch_size: if len(holder) == self.batch_size:
return self._aggregate_batch(holder, False) return self._aggregate_batch(holder, False)
...@@ -272,7 +281,7 @@ class Base(object): ...@@ -272,7 +281,7 @@ class Base(object):
except StopIteration: except StopIteration:
self.batch_generator = None self.batch_generator = None
self.epoch += 1 self.epoch += 1
# If we have left data in the epoch, return # If we have left data in the epoch, return
if len(holder) > 0: if len(holder) > 0:
return self._aggregate_batch(holder, False) return self._aggregate_batch(holder, False)
...@@ -281,4 +290,3 @@ class Base(object): ...@@ -281,4 +290,3 @@ class Base(object):
data = self.batch_generator.next() data = self.batch_generator.next()
holder.append(data) holder.append(data)
return self._aggregate_batch(holder, False) return self._aggregate_batch(holder, False)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment