Commit 70e0b851 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Allow shuffle on epoch end in generator

parent 6aabd230
import six
import tensorflow as tf
import random
import logging
logger = logging.getLogger(__name__)
......@@ -22,18 +23,21 @@ class Generator:
which takes a sample and loads it.
samples : [:obj:`object`]
A list of samples to be given to ``reader`` to load the data.
shuffle_on_epoch_end : :obj:`bool`, optional
If True, it shuffle the samples at the end of each epoch.
output_types : (object, object, object)
The types of the returned samples.
output_shapes : ``(tf.TensorShape, tf.TensorShape, tf.TensorShape)``
The shapes of the returned samples.
"""
def __init__(self, samples, reader, multiple_samples=False, **kwargs):
def __init__(self, samples, reader, multiple_samples=False, shuffle_on_epoch_end=False, **kwargs):
super().__init__(**kwargs)
self.reader = reader
self.samples = list(samples)
self.multiple_samples = multiple_samples
self.epoch = 0
self.shuffle_on_epoch_end = shuffle_on_epoch_end
# load one data to get its type and shape
dlk = self.reader(self.samples[0])
......@@ -81,31 +85,34 @@ class Generator:
yield dlk
self.epoch += 1
logger.info("Elapsed %d epoch(s)", self.epoch)
if self.shuffle_on_epoch_end:
logger.info("Shuffling samples")
random.shuffle(self.samples)
def dataset_using_generator(*args, **kwargs):
def dataset_using_generator(samples, reader, **kwargs):
"""
A generator class which wraps samples so that they can
be used with tf.data.Dataset.from_generator
Attributes
Parameters
----------
samples : [:obj:`object`]
A list of samples to be given to ``reader`` to load the data.
samples : [:obj:`object`]
A list of samples to be given to ``reader`` to load the data.
reader : :obj:`object`, optional
A callable with the signature of ``data, label, key = reader(sample)``
which takes a sample and loads it.
multiple_samples : :obj:`bool`, optional
If true, it assumes that the bio database's samples actually contain
multiple samples. This is useful for when you want to for example treat
video databases as image databases.
reader : :obj:`object`, optional
A callable with the signature of ``data, label, key = reader(sample)``
which takes a sample and loads it.
**kwargs
Extra keyword arguments are passed to Generator
Returns
-------
object
A tf.data.Dataset
"""
generator = Generator(*args, **kwargs)
generator = Generator(samples, reader, **kwargs)
dataset = tf.data.Dataset.from_generator(
generator, generator.output_types, generator.output_shapes
)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment