Commit 2efd0e68 authored by Tiago de Freitas Pereira

Refactoring the train

parent f4919f82
@@ -7,16 +7,15 @@
 Neural net work error rates analizer
 """
 import numpy
-import bob.measure
 from tensorflow.core.framework import summary_pb2


-class SoftmaxAnalizer:
+class SoftmaxAnalizer(object):
     """
     Analizer.
     """

-    def __init__(self, data_shuffler, machine, session):
+    def __init__(self):
         """
         Softmax analizer
@@ -31,43 +30,25 @@ class SoftmaxAnalizer:
         """
-        self.data_shuffler = data_shuffler
-        self.machine = machine
-        self.session = session

-    def __call__(self, graph=None):
-        """
-        placeholder_data, placeholder_labels = data_shuffler.get_placeholders(name="validation")
-        graph = machine.compute_graph(placeholder_data)
-        loss_validation = self.loss(validation_graph, validation_placeholder_labels)
-        tf.scalar_summary('accuracy', loss_validation, name="validation")
-        merged_validation = tf.merge_all_summaries()
-        """
-        validation_graph = self.compute_graph(self.data_shuffler, name="validation")
-        predictions = numpy.argmax(self.session.run(network,
-                                                    feed_dict={data_node: data[:]}), 1)
-        return 100. * numpy.sum(predictions == labels) / predictions.shape[0]
-
-        feed_dict = {validation_placeholder_data: validation_data,
-                     validation_placeholder_labels: validation_labels}
-
-        # l, predictions = session.run([loss_validation, validation_prediction, ], feed_dict=feed_dict)
-        # l, summary = session.run([loss_validation, merged_validation], feed_dict=feed_dict)
-        # import ipdb; ipdb.set_trace();
-        l = session.run(loss_validation, feed_dict=feed_dict)
+        self.data_shuffler = None
+        self.machine = None
+        self.session = None

+    def __call__(self, data_shuffler, machine, session):
+
+        if self.data_shuffler is None:
+            self.data_shuffler = data_shuffler
+            self.machine = machine
+            self.session = session
+
+        # Creating the graph
+        feature_batch, label_batch = self.data_shuffler.get_placeholders(name="validation_accuracy")
+        data, labels = self.data_shuffler.get_batch()
+        graph = self.machine.compute_graph(feature_batch)
+
+        predictions = numpy.argmax(self.session.run(graph, feed_dict={feature_batch: data[:]}), 1)
+        accuracy = 100. * numpy.sum(predictions == labels) / predictions.shape[0]
+
         summaries = []
-        summaries.append(summary_pb2.Summary.Value(tag="loss", simple_value=float(l)))
-        validation_writer.add_summary(summary_pb2.Summary(value=summaries), step)
+        summaries.append(summary_pb2.Summary.Value(tag="accuracy_validation", simple_value=float(accuracy)))
+        return summary_pb2.Summary(value=summaries)
\ No newline at end of file
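
The refactored analizer is now constructed with no arguments, lazily binds the data shuffler, machine, and session on its first call, and returns a Summary protobuf that the trainer writes to TensorBoard. A minimal, self-contained sketch of that accuracy-to-Summary step (illustrative code, not part of the commit):

import numpy
from tensorflow.core.framework import summary_pb2

def accuracy_summary(scores, labels):
    # scores: (batch, n_classes) network outputs; labels: (batch,) int targets
    predictions = numpy.argmax(scores, 1)
    accuracy = 100. * numpy.sum(predictions == labels) / predictions.shape[0]
    value = summary_pb2.Summary.Value(tag="accuracy_validation",
                                      simple_value=float(accuracy))
    # A SummaryWriter accepts this protobuf directly via add_summary(summary, step)
    return summary_pb2.Summary(value=[value])

scores = numpy.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
labels = numpy.array([0, 1, 1])
print(accuracy_summary(scores, labels))  # simple_value: ~66.67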
@@ -3,6 +3,7 @@ from pkgutil import extend_path
 __path__ = extend_path(__path__, __name__)

 from .ExperimentAnalizer import ExperimentAnalizer
+from .SoftmaxAnalizer import SoftmaxAnalizer

 # gets sphinx autodoc done right - don't remove it
 __all__ = [_ for _ in dir() if not _.startswith('_')]
@@ -87,18 +87,18 @@ def main():
                                           batch_size=VALIDATION_BATCH_SIZE)

     # Preparing the architecture
-    cnn = True
+    cnn = False
     if cnn:
         architecture = Chopra(seed=SEED)
         #architecture = Lenet(seed=SEED)
         #architecture = Dummy(seed=SEED)
         loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
-        trainer = Trainer(architecture=architecture, loss=loss, iterations=ITERATIONS)
+        trainer = Trainer(architecture=architecture, loss=loss, iterations=ITERATIONS, prefetch=False, temp_dir="cnn")
         trainer.train(train_data_shuffler, validation_data_shuffler)
         #trainer.train(train_data_shuffler)
     else:
         mlp = MLP(10, hidden_layers=[15, 20])
         loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
-        trainer = Trainer(architecture=mlp, loss=loss, iterations=ITERATIONS)
+        trainer = Trainer(architecture=mlp, loss=loss, iterations=ITERATIONS, temp_dir="./LOGS/dnn")
         trainer.train(train_data_shuffler, validation_data_shuffler)
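
Two details worth noting in this script change: the CNN branch now pins prefetch=False explicitly, and each branch gets its own temp_dir, which the trainer (see the train() hunk below) joins with 'train' to place its TensorBoard event files. A trivial check of where the MLP run logs land (illustrative only):

import os
print(os.path.join("./LOGS/dnn", "train"))  # -> ./LOGS/dnn/train, browsable with TensorBoard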
@@ -12,6 +12,7 @@ import numpy
 import os
 import bob.io.base
 import bob.core
+from ..analyzers import SoftmaxAnalizer
 from tensorflow.core.framework import summary_pb2

 logger = bob.core.log.setup("bob.learn.tensorflow")
@@ -34,6 +35,11 @@ class Trainer(object):
                  iterations=5000,
                  snapshot=100,
                  prefetch=False,
+
+                 ## Analizer
+                 analizer=SoftmaxAnalizer(),
+
                  verbosity_level=2):
         """
@@ -77,6 +83,12 @@ class Trainer(object):
         self.validation_graph = None
         self.validation_summary_writter = None

+        # Analizer
+        self.analizer = analizer
+
+        self.thread_pool = None
+        self.enqueue_op = None
+
         bob.core.log.set_verbosity_level(logger, verbosity_level)

     def compute_graph(self, data_shuffler, name=""):
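
One caveat with the new analizer=SoftmaxAnalizer() default: Python evaluates default arguments once, at function-definition time, so every Trainer built without an explicit analizer shares a single SoftmaxAnalizer instance, and that instance caches the first data shuffler, machine, and session it sees. A self-contained demonstration of the language behavior (illustrative names, not the repo's code):

class Analizer(object):
    pass

class Trainer(object):
    def __init__(self, analizer=Analizer()):  # default built once, at def time
        self.analizer = analizer

a, b = Trainer(), Trainer()
print(a.analizer is b.analizer)  # True: both trainers share one Analizer

The usual defensive idiom is analizer=None in the signature, with self.analizer = analizer if analizer is not None else SoftmaxAnalizer() in the body.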
@@ -103,7 +115,7 @@ class Trainer(object):
                                  shapes=[placeholder_data.get_shape().as_list()[1:], []])

             # Fetching the place holders from the queue
-            enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
+            self.enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
             feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)

             # Creating the architecture for train and validation
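
Promoting enqueue_op to an attribute is what lets the feeder thread (load_and_enqueue below) reach the op. The underlying prefetch pattern, sketched against the TensorFlow 0.x queue API this commit targets (the 28x28x1 shapes are an assumed MNIST-like example, not taken from the repo):

import tensorflow as tf

data_ph = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
label_ph = tf.placeholder(tf.int64, shape=[None])

queue = tf.FIFOQueue(capacity=100,
                     dtypes=[tf.float32, tf.int64],
                     shapes=[[28, 28, 1], []])
enqueue_op = queue.enqueue_many([data_ph, label_ph])     # producer side
feature_batch, label_batch = queue.dequeue_many(32)      # consumer side

# A feeder thread loops session.run(enqueue_op, feed_dict={...}) while the
# training step runs ops built from feature_batch with no feed_dict at all,
# which is exactly why the prefetch branch of __fit below can drop feed_dict.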
@@ -137,14 +149,15 @@ class Trainer(object):
     def __fit(self, session, step):
         if self.prefetch:
-            raise ValueError("Prefetch not implemented for such trainer")
+            _, l, lr, summary = session.run([self.optimizer, self.training_graph,
+                                             self.learning_rate, self.summaries_train])
         else:
             feed_dict = self.get_feed_dict(self.train_data_shuffler)
             _, l, lr, summary = session.run([self.optimizer, self.training_graph,
                                              self.learning_rate, self.summaries_train], feed_dict=feed_dict)

-            logger.info("Loss training set step={0} = {1}".format(step, l))
-            self.train_summary_writter.add_summary(summary, step)
+        logger.info("Loss training set step={0} = {1}".format(step, l))
+        self.train_summary_writter.add_summary(summary, step)

     def __compute_validation(self, session, data_shuffler, step):
@@ -166,33 +179,28 @@ class Trainer(object):
         tf.scalar_summary('lr', self.learning_rate, name="train")
         return tf.merge_all_summaries()
         """

-    def start_thread(self):
+    def start_thread(self, session):
         threads = []
         for n in range(1):
-            t = threading.Thread(target=self.load_and_enqueue)
+            t = threading.Thread(target=self.load_and_enqueue, args=(session, ))
             t.daemon = True  # thread will close when parent quits
             t.start()
             threads.append(t)
         return threads

-    def load_and_enqueue(self):
+    def load_and_enqueue(self, session):
         """
         Injecting data in the place holder queue
         """
-        #while not thread_pool.should_stop():
-        #for i in range(self.iterations):
-        while not thread_pool.should_stop():
-            train_data, train_labels = train_data_shuffler.get_batch()
+        while not self.thread_pool.should_stop():
+            train_data, train_labels = self.train_data_shuffler.get_batch()

             train_placeholder_data, train_placeholder_labels = self.train_data_shuffler.get_placeholders()

             feed_dict = {train_placeholder_data: train_data,
                          train_placeholder_labels: train_labels}

-            session.run(enqueue_op, feed_dict=feed_dict)
-        """
+            session.run(self.enqueue_op, feed_dict=feed_dict)

     def train(self, train_data_shuffler, validation_data_shuffler=None):
         """
@@ -230,38 +238,22 @@ class Trainer(object):
             tf.initialize_all_variables().run()

             # Start a thread to enqueue data asynchronously, and hide I/O latency.
-            #thread_pool = tf.train.Coordinator()
-            #tf.train.start_queue_runners(coord=thread_pool)
-            #threads = start_thread()
+            if self.prefetch:
+                self.thread_pool = tf.train.Coordinator()
+                tf.train.start_queue_runners(coord=self.thread_pool)
+                threads = self.start_thread(session)

             # TENSOR BOARD SUMMARY
             self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), session.graph)

             for step in range(self.iterations):
                 self.__fit(session, step)

                 if validation_data_shuffler is not None and step % self.snapshot == 0:
                     self.__compute_validation(session, validation_data_shuffler, step)

-                    # validation_data, validation_labels = validation_data_shuffler.get_batch()
-                    # feed_dict = {validation_placeholder_data: validation_data,
-                    #              validation_placeholder_labels: validation_labels}
-                    #l, predictions = session.run([loss_validation, validation_prediction, ], feed_dict=feed_dict)
-                    #l, summary = session.run([loss_validation, merged_validation], feed_dict=feed_dict)
-                    #import ipdb; ipdb.set_trace();
-                    # l = session.run(loss_validation, feed_dict=feed_dict)
-                    # summaries = []
-                    # summaries.append(summary_pb2.Summary.Value(tag="loss", simple_value=float(l)))
-                    # validation_writer.add_summary(summary_pb2.Summary(value=summaries), step)
-                    #l = session.run([loss_validation], feed_dict=feed_dict)
-                    #accuracy = 100. * numpy.sum(numpy.argmax(predictions, 1) == validation_labels) / predictions.shape[0]
-                    #validation_writer.add_summary(summary, step)
-                    #print "Step {0}. Loss = {1}, Acc Validation={2}".format(step, l, accuracy)
-                    # print "Step {0}. Loss = {1}".format(step, l)
+                    if self.analizer is not None:
+                        self.validation_summary_writter.add_summary(self.analizer(
+                            validation_data_shuffler, self.architecture, session), step)

             logger.info("Training finally finished")
@@ -272,9 +264,7 @@ class Trainer(object):
             self.architecture.save(hdf5)
             del hdf5

-            # now they should definetely stop
-            #thread_pool.request_stop()
-            #thread_pool.join(threads)
+            if self.prefetch:
+                # now they should definetely stop
+                self.thread_pool.request_stop()
+                self.thread_pool.join(threads)
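
With prefetch enabled, the trainer now owns the coordinator lifecycle end to end: create it, start the feeders, train, then request_stop() and join(). A self-contained sketch of just that shutdown contract (tf.train.Coordinator is the real API used above; the idle worker is illustrative):

import threading
import tensorflow as tf

coord = tf.train.Coordinator()

def worker():
    while not coord.should_stop():  # a feeder loop, as in load_and_enqueue
        pass

threads = [threading.Thread(target=worker) for _ in range(1)]
for t in threads:
    t.daemon = True
    t.start()

# ... training would happen here ...
coord.request_stop()  # ask the feeders to leave their loops
coord.join(threads)   # and block until they actually have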