Skip to content
Snippets Groups Projects

Gan

Closed Guillaume HEUSCH requested to merge gan into master
1 file
+ 7
5
Compare changes
  • Side-by-side
  • Inline
@@ -139,7 +139,7 @@ class Disk(Base):
def get_batch_epoch(self):
# this is done to rebuild the whole list (i.e. at the end of one epoch)
rebuild_indexes = False
epoch_done = False
# returned mini-batch
selected_data = numpy.zeros(shape=self.shape)
@@ -151,7 +151,7 @@ class Disk(Base):
print "should add examples to the current minibatch {0}".format(len(self.indexes))
# since we reached the end of an epoch, we'll hace to reconsider all the data
rebuild_indexes = True
epoch_done = True
number_of_examples_to_add = self.batch_size - len(self.indexes)
added_examples = 0
@@ -173,7 +173,9 @@ class Disk(Base):
for i in range(self.batch_size):
current_index = self.batch_size - i - 1
# TODO: try/catch for file loading
# get the data example
file_name = self.data[self.indexes[current_index]]
data = self.load_from_file(file_name)
@@ -193,8 +195,8 @@ class Disk(Base):
selected_labels = numpy.array(selected_labels)
# rebuild whole randomly shuffled training dataset
if rebuild_indexes:
if epoch_done:
self.indexes = numpy.array(range(self.data.shape[0]))
numpy.random.shuffle(self.indexes)
return [selected_data.astype("float32"), selected_labels.astype("int64")]
return [selected_data.astype("float32"), selected_labels.astype("int64"), epoch_done]
Loading