From 319ef692a88827eeb510d3fa2cbf0218a5f22c5d Mon Sep 17 00:00:00 2001
From: Guillaume HEUSCH <guillaume.heusch@idiap.ch>
Date: Mon, 24 Jul 2017 11:59:26 +0200
Subject: [PATCH] [datashuffler] added the get_batch_epoch function in Memory datashuffler

---
 bob/learn/tensorflow/datashuffler/Memory.py | 82 ++++++++++++++++++++++
 1 file changed, 82 insertions(+)

diff --git a/bob/learn/tensorflow/datashuffler/Memory.py b/bob/learn/tensorflow/datashuffler/Memory.py
index b0d2c898..d11dc578 100644
--- a/bob/learn/tensorflow/datashuffler/Memory.py
+++ b/bob/learn/tensorflow/datashuffler/Memory.py
@@ -62,6 +62,11 @@ class Memory(Base):
         # Setting the seed
         numpy.random.seed(seed)
         self.data = self.data.astype(input_dtype)
+
+        # indexes of the training examples, kept as an array
+        self.indexes = numpy.array(range(self.data.shape[0]))
+        # shuffle the indexes to get randomized mini-batches
+        numpy.random.shuffle(self.indexes)
 
     def get_batch(self):
         """
@@ -100,3 +105,80 @@ class Memory(Base):
             selected_data = self.normalize_sample(selected_data)
 
         return [selected_data.astype("float32"), selected_labels.astype("int64")]
+
+    def get_batch_epoch(self):
+        """get_batch_epoch() -> selected_data, selected_labels, epoch_done
+
+        This function selects and returns the data to be used in one mini-batch
+        iteration. Note that it works in epochs, i.e. all the training data is
+        seen during one epoch, which consists of several mini-batch iterations.
+
+        **Returns**
+
+        selected_data:
+            Selected samples
+
+        selected_labels:
+            Corresponding labels
+
+        epoch_done:
+            True when this mini-batch completes the current epoch
+        """
+        # set to True at the end of one epoch, so the whole index list is rebuilt
+        epoch_done = False
+
+        # returned mini-batch
+        selected_data = numpy.zeros(shape=self.shape)
+        selected_labels = []
+
+        # if there is not enough available data to fill the current mini-batch,
+        # randomly add some examples that are NOT already present in the pool!
+        if len(self.indexes) < self.batch_size:
+
+            print("should add examples to the current mini-batch: {0} left".format(len(self.indexes)))
+            # since we reached the end of an epoch, we'll have to reconsider all the data
+            epoch_done = True
+            number_of_examples_to_add = self.batch_size - len(self.indexes)
+            added_examples = 0
+
+            # generate a list of potential examples to add to this mini-batch
+            potential_indexes = numpy.array(range(self.data.shape[0]))
+            numpy.random.shuffle(potential_indexes)
+
+            # add indexes that are not already present in the pool
+            for pot_index in potential_indexes:
+                if pot_index not in self.indexes:
+                    self.indexes = numpy.append(self.indexes, [pot_index])
+                    added_examples += 1
+
+                # stop if we have enough examples
+                if added_examples == number_of_examples_to_add:
+                    break
+
+        # populate mini-batch
+        for i in range(self.batch_size):
+
+            current_index = self.batch_size - i - 1
+
+            # get the data example
+            selected_data[i, ...] = self.data[self.indexes[current_index], ...]
+
+            # normalization
+            selected_data[i, ...] = self.normalize_sample(selected_data[i, ...])
+
+            # label
+            selected_labels.append(self.labels[self.indexes[current_index]])
+
+            # remove this example from the pool: each example is used once per epoch
+            new_indexes = numpy.delete(self.indexes, current_index)
+            self.indexes = new_indexes
+
+        if isinstance(selected_labels, list):
+            selected_labels = numpy.array(selected_labels)
+
+        # rebuild the whole randomly shuffled training set for the next epoch
+        if epoch_done:
+            self.indexes = numpy.array(range(self.data.shape[0]))
+            numpy.random.shuffle(self.indexes)
+
+        return [selected_data.astype("float32"), selected_labels.astype("int64"), epoch_done]
--
GitLab
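
For illustration, here is a minimal sketch of how a training loop could consume the new get_batch_epoch() method, using the returned epoch_done flag to detect epoch boundaries. The names train_epochs and train_step are hypothetical and not part of this patch; only get_batch_epoch() comes from the code above.

    def train_epochs(shuffler, train_step, num_epochs=10):
        # `shuffler` is assumed to be an already-constructed Memory instance;
        # `train_step` is a hypothetical callable consuming one mini-batch.
        for epoch in range(num_epochs):
            epoch_done = False
            while not epoch_done:
                # epoch_done flips to True on the mini-batch that ends the epoch
                data, labels, epoch_done = shuffler.get_batch_epoch()
                train_step(data, labels)

Note that when fewer than batch_size examples remain in the pool, the method tops up the final mini-batch of the epoch with examples drawn from those already seen, so a few samples may appear twice within one pass over the data.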