From 70e0b8513bf343cdc2d4ed64063c090ff17795f9 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Fri, 28 Jun 2019 11:50:55 +0200
Subject: [PATCH] Allow shuffle on epoch end in generator

---
 bob/learn/tensorflow/dataset/generator.py | 39 +++++++++++++----------
 1 file changed, 23 insertions(+), 16 deletions(-)

diff --git a/bob/learn/tensorflow/dataset/generator.py b/bob/learn/tensorflow/dataset/generator.py
index fd157614..32d797b7 100644
--- a/bob/learn/tensorflow/dataset/generator.py
+++ b/bob/learn/tensorflow/dataset/generator.py
@@ -1,5 +1,6 @@
 import six
 import tensorflow as tf
+import random
 import logging
 
 logger = logging.getLogger(__name__)
@@ -22,18 +23,21 @@ class Generator:
         which takes a sample and loads it.
     samples : [:obj:`object`]
         A list of samples to be given to ``reader`` to load the data.
+    shuffle_on_epoch_end : :obj:`bool`, optional
+        If True, it shuffle the samples at the end of each epoch.
     output_types : (object, object, object)
         The types of the returned samples.
     output_shapes : ``(tf.TensorShape, tf.TensorShape, tf.TensorShape)``
         The shapes of the returned samples.
     """
 
-    def __init__(self, samples, reader, multiple_samples=False, **kwargs):
+    def __init__(self, samples, reader, multiple_samples=False, shuffle_on_epoch_end=False, **kwargs):
         super().__init__(**kwargs)
         self.reader = reader
         self.samples = list(samples)
         self.multiple_samples = multiple_samples
         self.epoch = 0
+        self.shuffle_on_epoch_end = shuffle_on_epoch_end
 
         # load one data to get its type and shape
         dlk = self.reader(self.samples[0])
@@ -81,31 +85,34 @@ class Generator:
                 yield dlk
         self.epoch += 1
         logger.info("Elapsed %d epoch(s)", self.epoch)
+        if self.shuffle_on_epoch_end:
+            logger.info("Shuffling samples")
+            random.shuffle(self.samples)
 
 
-def dataset_using_generator(*args, **kwargs):
+def dataset_using_generator(samples, reader, **kwargs):
     """
     A generator class which wraps samples so that they can
     be used with tf.data.Dataset.from_generator
 
-    Attributes
+    Parameters
     ----------
+    samples : [:obj:`object`]
+       A list of samples to be given to ``reader`` to load the data.
 
-     samples : [:obj:`object`]
-        A list of samples to be given to ``reader`` to load the data.
-
-     reader : :obj:`object`, optional
-        A callable with the signature of ``data, label, key = reader(sample)``
-        which takes a sample and loads it.
-
-     multiple_samples : :obj:`bool`, optional
-        If true, it assumes that the bio database's samples actually contain
-        multiple samples. This is useful for when you want to for example treat
-        video databases as image databases.
-     
+    reader : :obj:`object`, optional
+       A callable with the signature of ``data, label, key = reader(sample)``
+       which takes a sample and loads it.
+    **kwargs
+        Extra keyword arguments are passed to Generator
+
+    Returns
+    -------
+    object
+        A tf.data.Dataset
     """
 
-    generator = Generator(*args, **kwargs)
+    generator = Generator(samples, reader, **kwargs)
     dataset = tf.data.Dataset.from_generator(
         generator, generator.output_types, generator.output_shapes
     )
-- 
GitLab