Skip to content
Snippets Groups Projects
Commit 70e0b851 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Allow shuffle on epoch end in generator

parent 6aabd230
No related branches found
No related tags found
1 merge request!79Add keras-based models, add pixel-wise loss, other improvements
import six
import tensorflow as tf
import random
import logging
logger = logging.getLogger(__name__)
......@@ -22,18 +23,21 @@ class Generator:
which takes a sample and loads it.
samples : [:obj:`object`]
A list of samples to be given to ``reader`` to load the data.
shuffle_on_epoch_end : :obj:`bool`, optional
If True, it shuffle the samples at the end of each epoch.
output_types : (object, object, object)
The types of the returned samples.
output_shapes : ``(tf.TensorShape, tf.TensorShape, tf.TensorShape)``
The shapes of the returned samples.
"""
def __init__(self, samples, reader, multiple_samples=False, **kwargs):
def __init__(self, samples, reader, multiple_samples=False, shuffle_on_epoch_end=False, **kwargs):
super().__init__(**kwargs)
self.reader = reader
self.samples = list(samples)
self.multiple_samples = multiple_samples
self.epoch = 0
self.shuffle_on_epoch_end = shuffle_on_epoch_end
# load one data to get its type and shape
dlk = self.reader(self.samples[0])
......@@ -81,31 +85,34 @@ class Generator:
yield dlk
self.epoch += 1
logger.info("Elapsed %d epoch(s)", self.epoch)
if self.shuffle_on_epoch_end:
logger.info("Shuffling samples")
random.shuffle(self.samples)
def dataset_using_generator(*args, **kwargs):
def dataset_using_generator(samples, reader, **kwargs):
"""
A generator class which wraps samples so that they can
be used with tf.data.Dataset.from_generator
Attributes
Parameters
----------
samples : [:obj:`object`]
A list of samples to be given to ``reader`` to load the data.
samples : [:obj:`object`]
A list of samples to be given to ``reader`` to load the data.
reader : :obj:`object`, optional
A callable with the signature of ``data, label, key = reader(sample)``
which takes a sample and loads it.
multiple_samples : :obj:`bool`, optional
If true, it assumes that the bio database's samples actually contain
multiple samples. This is useful for when you want to for example treat
video databases as image databases.
reader : :obj:`object`, optional
A callable with the signature of ``data, label, key = reader(sample)``
which takes a sample and loads it.
**kwargs
Extra keyword arguments are passed to Generator
Returns
-------
object
A tf.data.Dataset
"""
generator = Generator(*args, **kwargs)
generator = Generator(samples, reader, **kwargs)
dataset = tf.data.Dataset.from_generator(
generator, generator.output_types, generator.output_shapes
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment