Commit 617bd939 authored by Tiago Pereira
Browse files

Disk datashuffler with generators

parent 087fd7a9
......@@ -95,10 +95,9 @@ class Base(object):
self.data_ph_from_queue = None
self.label_ph_from_queue = None
self.generator = None
self.batch_generator = None
self.epoch = 0
def create_placeholders(self):
"""
Create place holder instances
......@@ -149,11 +148,6 @@ class Base(object):
else:
return self.label_ph
def get_batch(self):
    """
    Shuffle dataset and get a random batch.

    This level provides no implementation: calling it always raises
    ``NotImplementedError``. The error message directs the caller to use
    one of the derived classes instead.
    """
    raise NotImplementedError("Method not implemented in this level. You should use one of the derived classes.")
def bob2skimage(self, bob_image):
"""
......@@ -221,8 +215,7 @@ class Base(object):
return self.normalizer(x)
@staticmethod
def _aggregate_batch(data_holder, use_list=False):
def _aggregate_batch(self, data_holder, use_list=False):
size = len(data_holder[0])
result = []
for k in range(size):
......@@ -232,9 +225,9 @@ class Base(object):
else:
dt = data_holder[0][k]
if type(dt) in [int, bool]:
tp = 'int32'
tp = 'int64'
elif type(dt) == float:
tp = 'float32'
tp = self.input_dtype
else:
try:
tp = dt.dtype
......@@ -251,3 +244,31 @@ class Base(object):
IP.embed(config=IP.terminal.ipapp.load_default_config())
return result
def get_batch(self):
    """
    Collect up to ``self.batch_size`` samples and aggregate them.

    Samples are pulled one at a time from ``self.batch_generator``, which
    is created lazily from ``self._fetch_batch()`` and kept between calls,
    so consecutive calls continue through the same pass over the data.

    When the generator runs out, ``self.epoch`` is incremented and the
    generator is dropped so that the next pass starts from a fresh
    ``self._fetch_batch()``. If some samples were already collected, they
    are returned as a shorter, final batch of that pass. If none were
    collected (the previous call consumed the last sample exactly at a
    batch boundary), a new pass is started right away so the caller never
    receives an empty batch from a non-empty dataset.

    ** Returns **
      The result of ``self._aggregate_batch(holder, False)``, where
      ``holder`` is the list of collected samples. If a freshly created
      generator yields nothing at all, ``holder`` is passed on empty.
    """
    holder = []
    restarted = False

    while len(holder) < self.batch_size:
        if self.batch_generator is None:
            self.batch_generator = self._fetch_batch()

        try:
            # The next() builtin works on Python 2 and 3 generators alike;
            # the .next() method only exists on Python 2.
            holder.append(next(self.batch_generator))
        except StopIteration:
            self.batch_generator = None

            if restarted and not holder:
                # A brand new generator produced nothing: there is no data
                # to restart on, so stop instead of looping forever.
                break

            self.epoch += 1

            if holder:
                # End of the pass: hand back the shorter final batch.
                break

            # The pass ended exactly on the previous batch boundary; begin
            # the next pass now rather than aggregating an empty holder.
            restarted = True

    return self._aggregate_batch(holder, False)
......@@ -105,7 +105,7 @@ class Disk(Base):
return data
def get_batch(self):
def _fetch_batch(self):
"""
Shuffle the Disk dataset, get a random batch and load it on the fly.
......@@ -118,23 +118,25 @@ class Disk(Base):
Correspondent labels
"""
shape = [self.batch_size] + list(self.input_shape[1:])
# Shuffling samples
indexes = numpy.array(range(self.data.shape[0]))
numpy.random.shuffle(indexes)
selected_data = numpy.zeros(shape=shape)
for i in range(self.batch_size):
#selected_data = numpy.zeros(shape=shape)
for i in indexes:
file_name = self.data[indexes[i]]
file_name = self.data[i]
data = self.load_from_file(file_name)
selected_data[i, ...] = data
if self.data_augmentation is not None:
data = self.skimage2bob(data)
data = self.data_augmentation(data)
data = self.bob2skimage(data)
# Scaling
selected_data[i, ...] = self.normalize_sample(selected_data[i, ...])
if self.normalize_sample is not None:
data = self.normalize_sample(data)
selected_labels = self.labels[indexes[0:self.batch_size]]
data = data.astype(self.input_dtype)
label = self.labels[i]
return [selected_data.astype("float32"), selected_labels.astype("int64")]
yield [data, label]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment