Commit b0e6d8db authored by Tiago Pereira

Using generators to speed up the batching #28

parent 5dd1debd
Pipeline #11298 failed with stages
in 23 minutes and 38 seconds
......@@ -216,3 +216,34 @@ class Base(object):
"""
return self.normalizer(x)
@staticmethod
def _aggregate_batch(data_holder, use_list=False):
    """
    Regroup a sequence of samples into one batch per field.

    Each element of ``data_holder`` is one sample, itself an indexable
    sequence of fields (for instance ``[sample, label]``).  The output has
    one entry per field of the first sample, gathering that field across
    all the samples in ``data_holder``.

    **Parameters**

    data_holder:
      Non-empty sequence of samples.  The number of fields is taken from
      the first sample.

    use_list:
      If ``True`` every field is returned as a plain python list.
      Otherwise every field is stacked into a :py:class:`numpy.ndarray`.

    **Returns**

    result:
      List with one batch (list or numpy array) per field.

    **Raises**

    TypeError:
      If a field of the first sample is neither an ``int``, ``bool``,
      ``float`` nor an object exposing ``dtype``.

    ValueError:
      If the values of one field cannot be stacked into a single array
      (e.g. samples of inconsistent shape).
    """
    size = len(data_holder[0])
    result = []
    for k in range(size):
        if use_list:
            result.append([x[k] for x in data_holder])
            continue

        # The dtype of the whole field is decided by the first sample.
        dt = data_holder[0][k]

        # Exact type matching on purpose: numpy scalars such as
        # numpy.float64 subclass the python builtins and must keep their
        # own dtype instead of being cast to the defaults below.
        if type(dt) in (int, bool):
            tp = 'int32'
        elif type(dt) is float:
            tp = 'float32'
        else:
            try:
                tp = dt.dtype
            except AttributeError:
                raise TypeError("Unsupported type to batch: {}".format(type(dt)))

        try:
            batch = numpy.asarray([x[k] for x in data_holder], dtype=tp)
        except Exception as e:
            # Fail loudly instead of dropping into an interactive shell and
            # returning a result with a missing field.
            raise ValueError(
                "Cannot batch data. Perhaps they are of inconsistent shape? "
                "(field {}: {})".format(k, e))
        result.append(batch)
    return result
......@@ -70,6 +70,26 @@ class Memory(Base):
numpy.random.seed(seed)
self.data = self.data.astype(input_dtype)
def _fetch_batch(self):
    """
    Generator that yields the samples of one random batch, one at a time.

    The positions of ``self.data`` are shuffled once and the first
    ``self.batch_size`` shuffled positions are visited in order.  Each
    sample optionally goes through ``self.data_augmentation`` (wrapped by
    the ``skimage2bob``/``bob2skimage`` conversions) and then through
    ``self.normalize_sample``.

    **Yields**

    A two-element list ``[sample, label]``.
    """
    # Random visiting order over the first axis of the data
    order = numpy.arange(self.data.shape[0])
    numpy.random.shuffle(order)

    for position in range(self.batch_size):
        chosen = order[position]
        sample, label = self.data[chosen, ...], self.labels[chosen]

        if self.data_augmentation is not None:
            # The augmentation runs on the bob layout; convert in and out
            sample = self.bob2skimage(
                self.data_augmentation(self.skimage2bob(sample)))

        if self.normalize_sample is not None:
            sample = self.normalize_sample(sample)

        yield [sample, label]
def get_batch(self):
"""
Shuffle the Memory dataset and get a random batch.
......@@ -82,20 +102,23 @@ class Memory(Base):
labels:
Correspondent labels
"""
# Shuffling samples
indexes = numpy.array(range(self.data.shape[0]))
numpy.random.shuffle(indexes)
selected_data = self.data[indexes[0:self.batch_size], ...]
selected_labels = self.labels[indexes[0:self.batch_size]]
holder = []
for d in self._fetch_batch():
holder.append(d)
data, labels = self._aggregate_batch(holder, False)
# Applying the data augmentation
if self.data_augmentation is not None:
for i in range(selected_data.shape[0]):
img = self.skimage2bob(selected_data[i, ...])
img = self.data_augmentation(img)
selected_data[i, ...] = self.bob2skimage(img)
return data, labels
selected_data = self.normalize_sample(selected_data)
#selected_data = self.data[indexes[0:self.batch_size], ...]
#selected_labels = self.labels[indexes[0:self.batch_size]]
return [selected_data.astype("float32"), selected_labels.astype("int64")]
# Applying the data augmentation
#if self.data_augmentation is not None:
# for i in range(selected_data.shape[0]):
# img = self.skimage2bob(selected_data[i, ...])
# img = self.data_augmentation(img)
# selected_data[i, ...] = self.bob2skimage(img)
#selected_data = self.normalize_sample(selected_data)
#return [selected_data.astype("float32"), selected_labels.astype("int64")]
......@@ -339,6 +339,7 @@ class Trainer(object):
start = time.time()
self.fit(step)
end = time.time()
summary = summary_pb2.Summary.Value(tag="elapsed_time", simple_value=float(end-start))
self.train_summary_writter.add_summary(summary_pb2.Summary(value=[summary]), step)
......@@ -355,7 +356,6 @@ class Trainer(object):
logger.info("Taking snapshot")
path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
self.saver.save(self.session, path, global_step=step)
#self.architecture.save(saver, path)
logger.info("Training finally finished")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment