From 70e0b8513bf343cdc2d4ed64063c090ff17795f9 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Fri, 28 Jun 2019 11:50:55 +0200 Subject: [PATCH] Allow shuffle on epoch end in generator --- bob/learn/tensorflow/dataset/generator.py | 39 +++++++++++++---------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/bob/learn/tensorflow/dataset/generator.py b/bob/learn/tensorflow/dataset/generator.py index fd157614..32d797b7 100644 --- a/bob/learn/tensorflow/dataset/generator.py +++ b/bob/learn/tensorflow/dataset/generator.py @@ -1,5 +1,6 @@ import six import tensorflow as tf +import random import logging logger = logging.getLogger(__name__) @@ -22,18 +23,21 @@ class Generator: which takes a sample and loads it. samples : [:obj:`object`] A list of samples to be given to ``reader`` to load the data. + shuffle_on_epoch_end : :obj:`bool`, optional + If True, shuffles the samples at the end of each epoch. output_types : (object, object, object) The types of the returned samples. output_shapes : ``(tf.TensorShape, tf.TensorShape, tf.TensorShape)`` The shapes of the returned samples. 
""" - def __init__(self, samples, reader, multiple_samples=False, **kwargs): + def __init__(self, samples, reader, multiple_samples=False, shuffle_on_epoch_end=False, **kwargs): super().__init__(**kwargs) self.reader = reader self.samples = list(samples) self.multiple_samples = multiple_samples self.epoch = 0 + self.shuffle_on_epoch_end = shuffle_on_epoch_end # load one data to get its type and shape dlk = self.reader(self.samples[0]) @@ -81,31 +85,34 @@ class Generator: yield dlk self.epoch += 1 logger.info("Elapsed %d epoch(s)", self.epoch) + if self.shuffle_on_epoch_end: + logger.info("Shuffling samples") + random.shuffle(self.samples) -def dataset_using_generator(*args, **kwargs): +def dataset_using_generator(samples, reader, **kwargs): """ A generator class which wraps samples so that they can be used with tf.data.Dataset.from_generator - Attributes + Parameters ---------- + samples : [:obj:`object`] + A list of samples to be given to ``reader`` to load the data. - samples : [:obj:`object`] - A list of samples to be given to ``reader`` to load the data. - - reader : :obj:`object`, optional - A callable with the signature of ``data, label, key = reader(sample)`` - which takes a sample and loads it. - - multiple_samples : :obj:`bool`, optional - If true, it assumes that the bio database's samples actually contain - multiple samples. This is useful for when you want to for example treat - video databases as image databases. - + reader : :obj:`object`, optional + A callable with the signature of ``data, label, key = reader(sample)`` + which takes a sample and loads it. + **kwargs + Extra keyword arguments are passed to Generator + + Returns + ------- + object + A tf.data.Dataset """ - generator = Generator(*args, **kwargs) + generator = Generator(samples, reader, **kwargs) dataset = tf.data.Dataset.from_generator( generator, generator.output_types, generator.output_shapes ) -- GitLab