Refactoring the train

parent f4919f82
@@ -7,16 +7,15 @@
 Neural network error rates analizer
 """

 import numpy
-import bob.measure
 from tensorflow.core.framework import summary_pb2


-class SoftmaxAnalizer:
+class SoftmaxAnalizer(object):
     """
     Analizer.
     """

-    def __init__(self, data_shuffler, machine, session):
+    def __init__(self):
         """
         Softmax analizer
@@ -31,43 +30,25 @@ class SoftmaxAnalizer:
         """
-        self.data_shuffler = data_shuffler
-        self.machine = machine
-        self.session = session
+        self.data_shuffler = None
+        self.machine = None
+        self.session = None

-        """
-        placeholder_data, placeholder_labels = data_shuffler.get_placeholders(name="validation")
-        graph = machine.compute_graph(placeholder_data)
-        loss_validation = self.loss(validation_graph, validation_placeholder_labels)
-        tf.scalar_summary('accuracy', loss_validation, name="validation")
-        merged_validation = tf.merge_all_summaries()
-        """
-
-    def __call__(self, graph=None):
-        validation_graph = self.compute_graph(self.data_shuffler, name="validation")
-        predictions = numpy.argmax(self.session.run(network,
-                                                    feed_dict={data_node: data[:]}), 1)
-        return 100. * numpy.sum(predictions == labels) / predictions.shape[0]
-
-        data, labels = self.data_shuffler.get_batch()
-        feed_dict = {validation_placeholder_data: validation_data,
-                     validation_placeholder_labels: validation_labels}
-        # l, predictions = session.run([loss_validation, validation_prediction, ], feed_dict=feed_dict)
-        # l, summary = session.run([loss_validation, merged_validation], feed_dict=feed_dict)
-        # import ipdb; ipdb.set_trace();
-        l = session.run(loss_validation, feed_dict=feed_dict)
-        summaries = []
-        summaries.append(summary_pb2.Summary.Value(tag="loss", simple_value=float(l)))
-        validation_writer.add_summary(summary_pb2.Summary(value=summaries), step)
\ No newline at end of file
+    def __call__(self, data_shuffler, machine, session):
+        if self.data_shuffler is None:
+            self.data_shuffler = data_shuffler
+            self.machine = machine
+            self.session = session
+
+        # Creating the graph
+        feature_batch, label_batch = self.data_shuffler.get_placeholders(name="validation_accuracy")
+        graph = self.machine.compute_graph(feature_batch)
+
+        data, labels = self.data_shuffler.get_batch()
+        predictions = numpy.argmax(self.session.run(graph, feed_dict={feature_batch: data[:]}), 1)
+        accuracy = 100. * numpy.sum(predictions == labels) / predictions.shape[0]
+
+        summaries = []
+        summaries.append(summary_pb2.Summary.Value(tag="accuracy_validation", simple_value=float(accuracy)))
+        return summary_pb2.Summary(value=summaries)
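The analizer is now constructed with no arguments and wired up lazily on its first call. A minimal usage sketch of the new protocol; the shuffler, architecture, session, writer, and step names below are assumptions standing in for the objects the Trainer already holds further down this diff:

```python
# Hedged sketch: driving the refactored SoftmaxAnalizer by hand.
# validation_data_shuffler, architecture, session, validation_writer
# and step are assumed to exist as in the Trainer code below.
analizer = SoftmaxAnalizer()

# The first call stores the dependencies; subsequent calls reuse them.
summary = analizer(validation_data_shuffler, architecture, session)

# The result is a summary_pb2.Summary with one value tagged
# "accuracy_validation", ready for a TensorBoard summary writer.
validation_writer.add_summary(summary, step)
```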
@@ -3,6 +3,7 @@ from pkgutil import extend_path
 __path__ = extend_path(__path__, __name__)

 from .ExperimentAnalizer import ExperimentAnalizer
+from .SoftmaxAnalizer import SoftmaxAnalizer

 # gets sphinx autodoc done right - don't remove it
 __all__ = [_ for _ in dir() if not _.startswith('_')]
@@ -87,18 +87,18 @@ def main():
                                             batch_size=VALIDATION_BATCH_SIZE)

     # Preparing the architecture
-    cnn = True
+    cnn = False
     if cnn:
         architecture = Chopra(seed=SEED)
         #architecture = Lenet(seed=SEED)
         #architecture = Dummy(seed=SEED)
         loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
-        trainer = Trainer(architecture=architecture, loss=loss, iterations=ITERATIONS)
+        trainer = Trainer(architecture=architecture, loss=loss, iterations=ITERATIONS, prefetch=False, temp_dir="cnn")
         trainer.train(train_data_shuffler, validation_data_shuffler)
         #trainer.train(train_data_shuffler)
     else:
         mlp = MLP(10, hidden_layers=[15, 20])
         loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
-        trainer = Trainer(architecture=mlp, loss=loss, iterations=ITERATIONS)
+        trainer = Trainer(architecture=mlp, loss=loss, iterations=ITERATIONS, temp_dir="./LOGS/dnn")
         trainer.train(train_data_shuffler, validation_data_shuffler)
@@ -12,6 +12,7 @@ import numpy
 import os
 import bob.io.base
 import bob.core
+from ..analyzers import SoftmaxAnalizer
 from tensorflow.core.framework import summary_pb2

 logger = bob.core.log.setup("bob.learn.tensorflow")
@@ -34,6 +35,11 @@ class Trainer(object):
                  iterations=5000,
                  snapshot=100,
                  prefetch=False,
+
+                 ## Analizer
+                 analizer=SoftmaxAnalizer(),
+
                  verbosity_level=2):
         """
@@ -77,6 +83,12 @@ class Trainer(object):
         self.validation_graph = None
         self.validation_summary_writter = None

+        # Analizer
+        self.analizer = analizer
+
+        self.thread_pool = None
+        self.enqueue_op = None
+
         bob.core.log.set_verbosity_level(logger, verbosity_level)

     def compute_graph(self, data_shuffler, name=""):
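One detail worth noting in the `@@ -34,6 +35,11 @@` hunk above: `analizer=SoftmaxAnalizer()` is evaluated once, when `__init__` is defined, so every `Trainer` built with the default shares a single `SoftmaxAnalizer` instance. That is mostly harmless here because the analizer only caches its dependencies on first call, but the usual sentinel pattern avoids the shared state entirely. A sketch, with unrelated parameters omitted:

```python
from ..analyzers import SoftmaxAnalizer


class Trainer(object):
    # Sketch only: a None default is resolved per instance, so no two
    # trainers accidentally share one analizer's cached session/machine.
    def __init__(self, architecture, loss, analizer=None, verbosity_level=2):
        self.architecture = architecture
        self.loss = loss
        self.analizer = analizer if analizer is not None else SoftmaxAnalizer()
```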
@@ -103,7 +115,7 @@ class Trainer(object):
                              shapes=[placeholder_data.get_shape().as_list()[1:], []])

         # Fetching the place holders from the queue
-        enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
+        self.enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])
         feature_batch, label_batch = queue.dequeue_many(data_shuffler.batch_size)

         # Creating the architecture for train and validation
@@ -137,14 +149,15 @@ class Trainer(object):
     def __fit(self, session, step):
         if self.prefetch:
-            raise ValueError("Prefetch not implemented for such trainer")
+            _, l, lr, summary = session.run([self.optimizer, self.training_graph,
+                                             self.learning_rate, self.summaries_train])
         else:
             feed_dict = self.get_feed_dict(self.train_data_shuffler)
             _, l, lr, summary = session.run([self.optimizer, self.training_graph,
                                              self.learning_rate, self.summaries_train], feed_dict=feed_dict)

         logger.info("Loss training set step={0} = {1}".format(step, l))
         self.train_summary_writter.add_summary(summary, step)

     def __compute_validation(self, session, data_shuffler, step):
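The prefetch branch can run the same fetches without a `feed_dict` because, with prefetching on, `compute_graph` builds the training graph on tensors dequeued from the FIFO queue rather than on raw placeholders. A reduced sketch of that wiring in the TF 0.x style used here; the shapes, capacity, and batch size are assumptions, not values from this diff:

```python
import tensorflow as tf

# Sketch: a queue-fed input pipeline. Placeholders are only touched by the
# enqueue op; the training graph hangs off the dequeued tensors instead.
placeholder_data = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
placeholder_labels = tf.placeholder(tf.int64, shape=(None,))

queue = tf.FIFOQueue(capacity=10,
                     dtypes=[tf.float32, tf.int64],
                     shapes=[[28, 28, 1], []])
enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])

# Anything built on feature_batch/label_batch pulls data straight from the
# queue when run, so session.run(training_graph) needs no feed_dict.
feature_batch, label_batch = queue.dequeue_many(32)
```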
@@ -166,33 +179,28 @@ class Trainer(object):
         tf.scalar_summary('lr', self.learning_rate, name="train")
         return tf.merge_all_summaries()

-    """
-    def start_thread(self):
+    def start_thread(self, session):
         threads = []
         for n in range(1):
-            t = threading.Thread(target=self.load_and_enqueue)
+            t = threading.Thread(target=self.load_and_enqueue, args=(session, ))
             t.daemon = True  # thread will close when parent quits
             t.start()
             threads.append(t)
         return threads

-    def load_and_enqueue(self):
+    def load_and_enqueue(self, session):
+        """
         Injecting data in the place holder queue
-
-        #while not thread_pool.should_stop():
-        #for i in range(self.iterations):
-        while not thread_pool.should_stop():
-            train_data, train_labels = train_data_shuffler.get_batch()
+        """
+        while not self.thread_pool.should_stop():
+            train_data, train_labels = self.train_data_shuffler.get_batch()
+            train_placeholder_data, train_placeholder_labels = self.train_data_shuffler.get_placeholders()

             feed_dict = {train_placeholder_data: train_data,
                          train_placeholder_labels: train_labels}

-            session.run(enqueue_op, feed_dict=feed_dict)
-    """
+            session.run(self.enqueue_op, feed_dict=feed_dict)

     def train(self, train_data_shuffler, validation_data_shuffler=None):
         """
@@ -230,38 +238,22 @@ class Trainer(object):
             tf.initialize_all_variables().run()

             # Start a thread to enqueue data asynchronously, and hide I/O latency.
-            #thread_pool = tf.train.Coordinator()
-            #tf.train.start_queue_runners(coord=thread_pool)
-            #threads = start_thread()
+            if self.prefetch:
+                self.thread_pool = tf.train.Coordinator()
+                tf.train.start_queue_runners(coord=self.thread_pool)
+                threads = self.start_thread(session)

             # TENSOR BOARD SUMMARY
             self.train_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), session.graph)

             for step in range(self.iterations):
                 self.__fit(session, step)

                 if validation_data_shuffler is not None and step % self.snapshot == 0:
                     self.__compute_validation(session, validation_data_shuffler, step)

-                    # validation_data, validation_labels = validation_data_shuffler.get_batch()
-
-                    # feed_dict = {validation_placeholder_data: validation_data,
-                    #              validation_placeholder_labels: validation_labels}
-
-                    #l, predictions = session.run([loss_validation, validation_prediction, ], feed_dict=feed_dict)
-                    #l, summary = session.run([loss_validation, merged_validation], feed_dict=feed_dict)
-                    #import ipdb; ipdb.set_trace();
-
-                    # l = session.run(loss_validation, feed_dict=feed_dict)
-
-                    # summaries = []
-                    # summaries.append(summary_pb2.Summary.Value(tag="loss", simple_value=float(l)))
-                    # validation_writer.add_summary(summary_pb2.Summary(value=summaries), step)
-
-                    #l = session.run([loss_validation], feed_dict=feed_dict)
-                    #accuracy = 100. * numpy.sum(numpy.argmax(predictions, 1) == validation_labels) / predictions.shape[0]
-                    #validation_writer.add_summary(summary, step)
-                    #print "Step {0}. Loss = {1}, Acc Validation={2}".format(step, l, accuracy)
-                    # print "Step {0}. Loss = {1}".format(step, l)
+                    if self.analizer is not None:
+                        self.validation_summary_writter.add_summary(self.analizer(
+                            validation_data_shuffler, self.architecture, session), step)

             logger.info("Training finally finished")
@@ -272,9 +264,7 @@ class Trainer(object):
                 self.architecture.save(hdf5)
                 del hdf5

-                # now they should definitely stop
-                #thread_pool.request_stop()
-                #thread_pool.join(threads)
+            if self.prefetch:
+                # now they should definitely stop
+                self.thread_pool.request_stop()
+                self.thread_pool.join(threads)