Commit caac0743 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Added tests for the Generator dataset

parent 01ac7f66
...@@ -29,7 +29,7 @@ class Generator: ...@@ -29,7 +29,7 @@ class Generator:
""" """
def __init__(self, samples, reader, multiple_samples=False, **kwargs): def __init__(self, samples, reader, multiple_samples=False, **kwargs):
super(Generator, self).__init__(**kwargs) super().__init__(**kwargs)
self.reader = reader self.reader = reader
self.samples = list(samples) self.samples = list(samples)
self.multiple_samples = multiple_samples self.multiple_samples = multiple_samples
...@@ -43,6 +43,7 @@ class Generator: ...@@ -43,6 +43,7 @@ class Generator:
except TypeError: except TypeError:
# if the data is a generator # if the data is a generator
dlk = six.next(dlk) dlk = six.next(dlk)
# Creating a "fake" dataset just to get the types and shapes
dataset = tf.data.Dataset.from_tensors(dlk) dataset = tf.data.Dataset.from_tensors(dlk)
self._output_types = dataset.output_types self._output_types = dataset.output_types
self._output_shapes = dataset.output_shapes self._output_shapes = dataset.output_shapes
......
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch> # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
import pkg_resources import pkg_resources
import numpy
import tensorflow as tf import tensorflow as tf
from bob.learn.tensorflow.dataset.siamese_image import shuffle_data_and_labels_image_augmentation as siamese_batch from bob.learn.tensorflow.dataset.siamese_image import shuffle_data_and_labels_image_augmentation as siamese_batch
from bob.learn.tensorflow.dataset.triplet_image import shuffle_data_and_labels_image_augmentation as triplet_batch from bob.learn.tensorflow.dataset.triplet_image import shuffle_data_and_labels_image_augmentation as triplet_batch
from bob.learn.tensorflow.dataset.generator import dataset_using_generator
data_shape = (250, 250, 3) data_shape = (250, 250, 3)
output_shape = (50, 50) output_shape = (50, 50)
...@@ -76,3 +78,27 @@ def test_triplet_dataset(): ...@@ -76,3 +78,27 @@ def test_triplet_dataset():
assert d['anchor'].shape == (2, 50, 50, 3) assert d['anchor'].shape == (2, 50, 50, 3)
assert d['positive'].shape == (2, 50, 50, 3) assert d['positive'].shape == (2, 50, 50, 3)
assert d['negative'].shape == (2, 50, 50, 3) assert d['negative'].shape == (2, 50, 50, 3)
def test_dataset_using_generator():
def reader(f):
key = 0
label = 0
yield {'data': f, 'key': key}, label
shape = (2, 2, 1)
samples = [numpy.ones(shape, dtype="float32")*i for i in range(10)]
with tf.Session() as session:
dataset = dataset_using_generator(samples,\
reader,\
multiple_samples=True)
iterator = dataset.make_one_shot_iterator().get_next()
while True:
try:
sample = session.run(iterator)
assert sample[0]["data"].shape == shape
except tf.errors.OutOfRangeError:
break
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment