Commit 0b0055bf authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Add a generic Generator class alternative to BioGenerator

parent c77e386e
import six
import tensorflow as tf
import logging
logger = logging.getLogger(__name__)
class Generator:
"""A generator class which wraps samples so that they can
be used with tf.data.Dataset.from_generator
Attributes
----------
epoch : int
The number of epochs that have been passed so far.
multiple_samples : :obj:`bool`, optional
If true, it assumes that the bio database's samples actually contain
multiple samples. This is useful for when you want to for example treat
video databases as image databases.
reader : :obj:`object`, optional
A callable with the signature of ``data, label, key = reader(sample)``
which takes a sample and loads it.
samples : [:obj:`object`]
A list of samples to be given to ``reader`` to load the data.
output_types : (object, object, object)
The types of the returned samples.
output_shapes : ``(tf.TensorShape, tf.TensorShape, tf.TensorShape)``
The shapes of the returned samples.
"""
def __init__(self, samples, reader, multiple_samples=False, **kwargs):
super(Generator, self).__init__(**kwargs)
Please register or sign in to reply
self.reader = reader
self.samples = list(samples)
self.multiple_samples = multiple_samples
self.epoch = 0
# load one data to get its type and shape
dlk = self.reader(self.samples[0])
if self.multiple_samples:
try:
dlk = dlk[0]
except TypeError:
# if the data is a generator
dlk = six.next(dlk)
dataset = tf.data.Dataset.from_tensors(dlk)
self._output_types = dataset.output_types
self._output_shapes = dataset.output_shapes
logger.info(
"Initializing a dataset with %d %s and %s types and %s shapes",
len(self.samples),
"multi-samples" if self.multiple_samples else "samples",
self.output_types,
self.output_shapes,
)
@property
def output_types(self):
return self._output_types
@property
def output_shapes(self):
return self._output_shapes
def __call__(self):
"""A generator function that when called will yield the samples.
Yields
------
(data, label, key) : tuple
A tuple containing the data, label, and the key.
"""
for sample in self.samples:
dlk = self.reader(sample)
if self.multiple_samples:
for sub_dlk in dlk:
yield sub_dlk
else:
yield dlk
self.epoch += 1
logger.info("Elapsed %d epoch(s)", self.epoch)
def dataset_using_generator(*args, **kwargs):
generator = Generator(*args, **kwargs)
dataset = tf.data.Dataset.from_generator(
generator, generator.output_types, generator.output_shapes
)
return dataset
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment