From 319ef692a88827eeb510d3fa2cbf0218a5f22c5d Mon Sep 17 00:00:00 2001
From: Guillaume HEUSCH <guillaume.heusch@idiap.ch>
Date: Mon, 24 Jul 2017 11:59:26 +0200
Subject: [PATCH] [datashuffler] added the get_batch_epoch function in Memory datashuffler

---
 bob/learn/tensorflow/datashuffler/Memory.py | 82 ++++++++++++++++++++++
 1 file changed, 82 insertions(+)

diff --git a/bob/learn/tensorflow/datashuffler/Memory.py b/bob/learn/tensorflow/datashuffler/Memory.py
index b0d2c898..d11dc578 100644
--- a/bob/learn/tensorflow/datashuffler/Memory.py
+++ b/bob/learn/tensorflow/datashuffler/Memory.py
@@ -62,6 +62,11 @@ class Memory(Base):
         # Setting the seed
         numpy.random.seed(seed)
         self.data = self.data.astype(input_dtype)
+
+        # indexes of the training examples, kept as an array
+        self.indexes = numpy.array(range(self.data.shape[0]))
+        # shuffle the indexes to get randomized mini-batches
+        numpy.random.shuffle(self.indexes)
 
     def get_batch(self):
         """
@@ -100,3 +105,80 @@ class Memory(Base):
             selected_data = self.normalize_sample(selected_data)
 
         return [selected_data.astype("float32"), selected_labels.astype("int64")]
+
+    def get_batch_epoch(self):
+        """get_batch_epoch() -> selected_data, selected_labels, epoch_done
+
+        This function selects and returns the data to be used in one mini-batch
+        iteration. Note that it works in epochs, i.e. all the training data is
+        seen during one epoch, which consists of several mini-batch iterations.
+
+        **Returns**
+
+        selected_data:
+            Selected samples
+
+        selected_labels:
+            Corresponding labels
+
+        epoch_done:
+            True when this mini-batch completes the current epoch
+        """
+        # set to True at the end of one epoch, so the whole index list is rebuilt
+        epoch_done = False
+
+        # returned mini-batch
+        selected_data = numpy.zeros(shape=self.shape)
+        selected_labels = []
+
+        # if there is not enough available data to fill the current mini-batch,
+        # randomly add some examples that are NOT already present in the pool!
+        if len(self.indexes) < self.batch_size:
+
+            print("should add examples to the current mini-batch: {0} left".format(len(self.indexes)))
+            # since we reached the end of an epoch, we'll have to reconsider all the data
+            epoch_done = True
+            number_of_examples_to_add = self.batch_size - len(self.indexes)
+            added_examples = 0
+
+            # generate a list of potential examples to add to this mini-batch
+            potential_indexes = numpy.array(range(self.data.shape[0]))
+            numpy.random.shuffle(potential_indexes)
+
+            # add indexes that are not already present in the pool
+            for pot_index in potential_indexes:
+                if pot_index not in self.indexes:
+                    self.indexes = numpy.append(self.indexes, [pot_index])
+                    added_examples += 1
+
+                # stop if we have enough examples
+                if added_examples == number_of_examples_to_add:
+                    break
+
+        # populate mini-batch
+        for i in range(self.batch_size):
+
+            current_index = self.batch_size - i - 1
+
+            # get the data example
+            selected_data[i, ...] = self.data[self.indexes[current_index], ...]
+
+            # normalization
+            selected_data[i, ...] = self.normalize_sample(selected_data[i, ...])
+
+            # label
+            selected_labels.append(self.labels[self.indexes[current_index]])
+
+            # remove this example from the pool: each example is used once per epoch
+            new_indexes = numpy.delete(self.indexes, current_index)
+            self.indexes = new_indexes
+
+        if isinstance(selected_labels, list):
+            selected_labels = numpy.array(selected_labels)
+
+        # rebuild the whole randomly shuffled training set for the next epoch
+        if epoch_done:
+            self.indexes = numpy.array(range(self.data.shape[0]))
+            numpy.random.shuffle(self.indexes)
+
+        return [selected_data.astype("float32"), selected_labels.astype("int64"), epoch_done]
--
GitLab
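
For illustration, here is a minimal sketch of how a training loop could consume the new get_batch_epoch() method, using the returned epoch_done flag to detect epoch boundaries. The names train_epochs and train_step are hypothetical and not part of this patch; only get_batch_epoch() comes from the code above.

    def train_epochs(shuffler, train_step, num_epochs=10):
        # `shuffler` is assumed to be an already-constructed Memory instance;
        # `train_step` is a hypothetical callable consuming one mini-batch.
        for epoch in range(num_epochs):
            epoch_done = False
            while not epoch_done:
                # epoch_done flips to True on the mini-batch that ends the epoch
                data, labels, epoch_done = shuffler.get_batch_epoch()
                train_step(data, labels)

Note that when fewer than batch_size examples remain in the pool, the method tops up the final mini-batch of the epoch with examples drawn from those already seen, so a few samples may appear twice within one pass over the data.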