Skip to content
Snippets Groups Projects

Gan

Closed Guillaume HEUSCH requested to merge gan into master
1 file
+ 69
0
Compare changes
  • Side-by-side
  • Inline
@@ -77,6 +77,11 @@ class Disk(Base):
# TODO: very bad solution to deal with bob.shape images and tf shape images
self.bob_shape = tuple([input_shape[2]] + list(input_shape[0:2]))
# number of training examples as a 'list'
self.indexes = numpy.array(range(self.data.shape[0]))
# shuffle the indexes to get randomized mini-batches
numpy.random.shuffle(self.indexes)
def load_from_file(self, file_name):
d = bob.io.base.load(file_name)
@@ -129,3 +134,67 @@ class Disk(Base):
selected_labels = self.labels[indexes[0:self.batch_size]]
return [selected_data.astype("float32"), selected_labels.astype("int64")]
def get_batch_epoch(self):
    """Return one mini-batch, sampling each example exactly once per epoch.

    Examples are drawn without replacement from ``self.indexes``. When fewer
    than ``self.batch_size`` indexes remain, the batch is padded with randomly
    chosen examples that were already consumed earlier in the epoch, and the
    index list is rebuilt (and reshuffled) afterwards so a new epoch begins.

    Returns:
        list: ``[data, labels]`` where ``data`` is a float32 array of shape
        ``self.shape`` and ``labels`` is an int64 array with
        ``self.batch_size`` entries.
    """
    # set when we exhaust the index list (i.e. at the end of one epoch)
    rebuild_indexes = False

    # returned mini-batch
    selected_data = numpy.zeros(shape=self.shape)
    selected_labels = []

    # if there is not enough available data to fill the current mini-batch,
    # pad with randomly chosen examples that are NOT currently in the
    # remaining index list (i.e. examples already used in this epoch)
    if len(self.indexes) < self.batch_size:
        # print() call form works under both Python 2 and Python 3
        print("should add examples to the current minibatch {0}".format(len(self.indexes)))
        # since we reached the end of an epoch, we'll have to reconsider all the data
        rebuild_indexes = True
        number_of_examples_to_add = self.batch_size - len(self.indexes)

        # candidate examples in random order; a set gives O(1) membership
        # tests instead of scanning the numpy array for every candidate,
        # and a single append avoids per-element reallocation
        potential_indexes = numpy.array(range(self.data.shape[0]))
        numpy.random.shuffle(potential_indexes)
        remaining = set(self.indexes.tolist())
        to_add = [p for p in potential_indexes if p not in remaining]
        self.indexes = numpy.append(self.indexes, to_add[:number_of_examples_to_add])

    # populate the mini-batch: consume the first batch_size indexes, walking
    # positions batch_size-1 down to 0 (same order the original per-element
    # numpy.delete loop produced)
    for i in range(self.batch_size):
        current_index = self.batch_size - i - 1
        example_index = self.indexes[current_index]
        # load the data example
        file_name = self.data[example_index]
        data = self.load_from_file(file_name)
        selected_data[i, ...] = data
        # normalization (applied after the cast into selected_data's dtype,
        # matching the original statement order)
        selected_data[i, ...] = self.normalize_sample(selected_data[i, ...])
        # label
        selected_labels.append(self.labels[example_index])
    # drop all consumed indexes in one O(n) slice instead of batch_size
    # individual numpy.delete calls — each example is used once per epoch
    self.indexes = self.indexes[self.batch_size:]

    if isinstance(selected_labels, list):
        selected_labels = numpy.array(selected_labels)

    # rebuild the whole, randomly shuffled training set for the next epoch
    if rebuild_indexes:
        self.indexes = numpy.array(range(self.data.shape[0]))
        numpy.random.shuffle(self.indexes)

    return [selected_data.astype("float32"), selected_labels.astype("int64")]
Loading