Commit 70e0b851 authored by Amir MOHAMMADI

Allow shuffle on epoch end in generator

parent 6aabd230
import six import six
import tensorflow as tf import tensorflow as tf
import random
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -22,18 +23,21 @@ class Generator: ...@@ -22,18 +23,21 @@ class Generator:
which takes a sample and loads it. which takes a sample and loads it.
samples : [:obj:`object`] samples : [:obj:`object`]
A list of samples to be given to ``reader`` to load the data. A list of samples to be given to ``reader`` to load the data.
shuffle_on_epoch_end : :obj:`bool`, optional
If True, the samples are shuffled at the end of each epoch.
output_types : (object, object, object) output_types : (object, object, object)
The types of the returned samples. The types of the returned samples.
output_shapes : ``(tf.TensorShape, tf.TensorShape, tf.TensorShape)`` output_shapes : ``(tf.TensorShape, tf.TensorShape, tf.TensorShape)``
The shapes of the returned samples. The shapes of the returned samples.
""" """
def __init__(self, samples, reader, multiple_samples=False, **kwargs): def __init__(self, samples, reader, multiple_samples=False, shuffle_on_epoch_end=False, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.reader = reader self.reader = reader
self.samples = list(samples) self.samples = list(samples)
self.multiple_samples = multiple_samples self.multiple_samples = multiple_samples
self.epoch = 0 self.epoch = 0
self.shuffle_on_epoch_end = shuffle_on_epoch_end
# load one data to get its type and shape # load one data to get its type and shape
dlk = self.reader(self.samples[0]) dlk = self.reader(self.samples[0])
...@@ -81,31 +85,34 @@ class Generator: ...@@ -81,31 +85,34 @@ class Generator:
yield dlk yield dlk
self.epoch += 1 self.epoch += 1
logger.info("Elapsed %d epoch(s)", self.epoch) logger.info("Elapsed %d epoch(s)", self.epoch)
if self.shuffle_on_epoch_end:
logger.info("Shuffling samples")
random.shuffle(self.samples)
def dataset_using_generator(*args, **kwargs): def dataset_using_generator(samples, reader, **kwargs):
""" """
A generator class which wraps samples so that they can A generator class which wraps samples so that they can
be used with tf.data.Dataset.from_generator be used with tf.data.Dataset.from_generator
Attributes Parameters
---------- ----------
samples : [:obj:`object`]
A list of samples to be given to ``reader`` to load the data.
samples : [:obj:`object`] reader : :obj:`object`, optional
A list of samples to be given to ``reader`` to load the data. A callable with the signature of ``data, label, key = reader(sample)``
which takes a sample and loads it.
reader : :obj:`object`, optional **kwargs
A callable with the signature of ``data, label, key = reader(sample)`` Extra keyword arguments are passed to Generator
which takes a sample and loads it.
Returns
multiple_samples : :obj:`bool`, optional -------
If true, it assumes that the bio database's samples actually contain object
multiple samples. This is useful for when you want to for example treat A tf.data.Dataset
video databases as image databases.
""" """
generator = Generator(*args, **kwargs) generator = Generator(samples, reader, **kwargs)
dataset = tf.data.Dataset.from_generator( dataset = tf.data.Dataset.from_generator(
generator, generator.output_types, generator.output_shapes generator, generator.output_types, generator.output_shapes
) )
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment