Still struggling with a general training

parent 14147b9a
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Tue 09 Aug 2016 15:33 CEST
"""
Neural network error rate analyzer
"""
import numpy
import bob.measure
from tensorflow.core.framework import summary_pb2
from scipy.spatial.distance import cosine
class ExperimentAnalizer:
    """
    Analyzer that runs the validation set through the current network
    and logs the validation loss as a TensorBoard summary.
    """

    def __init__(self, data_shuffler, machine, session, loss, validation_writer):
        """
        Use the CNN as feature extractor for an n-class classification

        ** Parameters **

        data_shuffler: Shuffler that provides validation batches and placeholders
        machine: Architecture used to compute the validation graph
        session: TensorFlow session
        loss: Loss function applied to the validation graph
        validation_writer: `tf.train.SummaryWriter` that receives the summaries
        """
        self.data_shuffler = data_shuffler
        self.machine = machine
        self.session = session
        self.validation_writer = validation_writer

        # Building the validation graph on top of the validation placeholders
        self.placeholder_data, self.placeholder_labels = \
            data_shuffler.get_placeholders(name="validation")
        graph = machine.compute_graph(self.placeholder_data)
        self.loss_validation = loss(graph, self.placeholder_labels)

    def __call__(self, step):
        data, labels = self.data_shuffler.get_batch()
        feed_dict = {self.placeholder_data: data,
                     self.placeholder_labels: labels}

        l = self.session.run(self.loss_validation, feed_dict=feed_dict)

        # The validation loss is a plain Python float here, so it is written
        # as a manually built summary protobuf instead of a summary op
        summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
        self.validation_writer.add_summary(summary_pb2.Summary(value=summaries), step)
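
For orientation, a minimal sketch of how this analyzer could be wired into the training loop. The shuffler, architecture, loss and iteration names below mirror the Trainer changes in this commit, but the wiring itself (and the `loss_function` name) is an assumption, not code from this file:

import tensorflow as tf

session = tf.Session()
validation_writer = tf.train.SummaryWriter('cnn/validation', session.graph)

# Hypothetical wiring: the analyzer owns the validation graph and is
# invoked every `snapshot` steps from the training loop
analizer = ExperimentAnalizer(data_shuffler=validation_data_shuffler,
                              machine=architecture,
                              session=session,
                              loss=loss_function,
                              validation_writer=validation_writer)

for step in range(iterations):
    # ... one optimizer step on the prefetch queue goes here ...
    if step % snapshot == 0:
        analizer(step)  # writes the current validation loss for TensorBoard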
@@ -56,7 +56,7 @@ class MemoryDataShuffler(BaseDataShuffler):

         selected_data = self.data[indexes[0:self.batch_size], :, :, :]
         selected_labels = self.labels[indexes[0:self.batch_size]]

-        return selected_data.astype("float32"), selected_labels
+        return selected_data, selected_labels.astype("int64")

     def get_pair(self, zero_one_labels=True):
         """
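
A note on the casting swap above: the prefetch queue declared further down in this commit fixes the dtypes of its two slots to `tf.float32` and `tf.int64`, so the shuffler now leaves the data untouched and casts the labels instead. A tiny illustration of that contract, with the element shape assumed rather than taken from the repository:

import tensorflow as tf

# The queue only accepts batches matching these dtypes, so labels
# must arrive as int64 (the image shape here is an assumption)
queue = tf.FIFOQueue(capacity=10,
                     dtypes=[tf.float32, tf.int64],
                     shapes=[(28, 28, 1), ()])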
@@ -11,6 +11,7 @@ import threading
 import numpy
 import os
 import bob.io.base
+from tensorflow.core.framework import summary_pb2

 class Trainer(object):
@@ -20,7 +21,7 @@ class Trainer(object):
                  optimizer=tf.train.AdamOptimizer(),
                  use_gpu=False,
                  loss=None,
-                 temp_dir="",
+                 temp_dir="cnn",

                  # Learning rate
                  base_learning_rate=0.001,
@@ -96,11 +97,14 @@ class Trainer(object):
                                              self.weight_decay  # Decay step
                                              )

+        # Creating directory
+        bob.io.base.create_directories_safe(self.temp_dir)
+
         # Defining place holders
         train_placeholder_data, train_placeholder_labels = train_data_shuffler.get_placeholders_forprefetch(name="train")

-        if validation_data_shuffler is not None:
-            validation_placeholder_data, validation_placeholder_labels = \
-                validation_data_shuffler.get_placeholders(name="validation")
+        #if validation_data_shuffler is not None:
+        #    validation_placeholder_data, validation_placeholder_labels = \
+        #        validation_data_shuffler.get_placeholders(name="validation")

         # Defining a placeholder queue for prefetching
         queue = tf.FIFOQueue(capacity=10,
                              dtypes=[tf.float32, tf.int64],
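
For readers unfamiliar with the pattern being set up here: the placeholders are fed by a Python thread while the training graph reads from the queue, decoupling data loading from the optimizer. A rough self-contained sketch, with shapes and batch size assumed:

import tensorflow as tf

placeholder_data = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
placeholder_labels = tf.placeholder(tf.int64, shape=(None,))

queue = tf.FIFOQueue(capacity=10,
                     dtypes=[tf.float32, tf.int64],
                     shapes=[(28, 28, 1), ()])
enqueue_op = queue.enqueue_many([placeholder_data, placeholder_labels])

# The training graph reads dequeued batches; it never sees the placeholders
feature_batch, label_batch = queue.dequeue_many(64)

def feeder(session, data_shuffler):
    # Runs in a background thread, pushing fresh batches into the queue
    while True:
        data, labels = data_shuffler.get_batch()
        session.run(enqueue_op, feed_dict={placeholder_data: data,
                                           placeholder_labels: labels})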
@@ -118,16 +122,23 @@ class Trainer(object):
         # Creating graphs and defining the loss
         train_graph = self.architecture.compute_graph(train_feature_batch)
         loss_train = self.loss(train_graph, train_label_batch)
-        train_prediction = tf.nn.softmax(train_graph)
-
-        if validation_data_shuffler is not None:
-            validation_graph = self.architecture.compute_graph(validation_placeholder_data)
-            loss_validation = self.loss(validation_graph, validation_placeholder_labels)
-            validation_prediction = tf.nn.softmax(validation_graph)

         # Preparing the optimizer
         self.optimizer._learning_rate = learning_rate
         optimizer = self.optimizer.minimize(loss_train, global_step=tf.Variable(0))

+        # Train summary
+        tf.scalar_summary('loss', loss_train, name="train")
+        tf.scalar_summary('lr', learning_rate, name="train")
+        merged_train = tf.merge_all_summaries()
+
+        # Validation
+        #if validation_data_shuffler is not None:
+        #    validation_graph = self.architecture.compute_graph(validation_placeholder_data)
+        #    loss_validation = self.loss(validation_graph, validation_placeholder_labels)
+        #    tf.scalar_summary('loss', loss_validation, name="validation")
+        #    merged_validation = tf.merge_all_summaries()
+
         print("Initializing !!")

         # Training
         hdf5 = bob.io.base.HDF5File(os.path.join(self.temp_dir, 'model.hdf5'), 'w')
@@ -142,12 +153,15 @@ class Trainer(object):
             threads = start_thread()

-            train_writer = tf.train.SummaryWriter('./LOGS/train', session.graph)
+            # TENSOR BOARD SUMMARY
+            train_writer = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'train'), session.graph)
+            validation_writer = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), session.graph)

             for step in range(self.iterations):

-                _, l, lr, _ = session.run([optimizer, loss_train,
-                                           learning_rate, train_prediction])
+                _, l, lr, summary = session.run([optimizer, loss_train,
+                                                 learning_rate, merged_train])
+                train_writer.add_summary(summary, step)

                 if validation_data_shuffler is not None and step % self.snapshot == 0:
                     validation_data, validation_labels = validation_data_shuffler.get_batch()
@@ -155,16 +169,27 @@ class Trainer(object):
                     feed_dict = {validation_placeholder_data: validation_data,
                                  validation_placeholder_labels: validation_labels}

-                    l, predictions = session.run([loss_validation, validation_prediction], feed_dict=feed_dict)
-                    accuracy = 100. * numpy.sum(numpy.argmax(predictions, 1) == validation_labels) / predictions.shape[0]
+                    #l, predictions = session.run([loss_validation, validation_prediction, ], feed_dict=feed_dict)
+                    #l, summary = session.run([loss_validation, merged_validation], feed_dict=feed_dict)
+                    #import ipdb; ipdb.set_trace();
+                    l = session.run(loss_validation, feed_dict=feed_dict)
+
+                    summaries = []
+                    summaries.append(summary_pb2.Summary.Value(tag="loss", simple_value=float(l)))
+                    validation_writer.add_summary(summary_pb2.Summary(value=summaries), step)

-                    print "Step {0}. Loss = {1}, Acc Validation={2}".format(step, l, accuracy)
+                    #l = session.run([loss_validation], feed_dict=feed_dict)
+                    #accuracy = 100. * numpy.sum(numpy.argmax(predictions, 1) == validation_labels) / predictions.shape[0]
+                    #validation_writer.add_summary(summary, step)
+                    #print "Step {0}. Loss = {1}, Acc Validation={2}".format(step, l, accuracy)
+                    print "Step {0}. Loss = {1}".format(step, l)

             train_writer.close()

+            self.architecture.save(hdf5)
+            del hdf5
+
             # now they should definitely stop
             thread_pool.request_stop()
             thread_pool.join(threads)

-        self.architecture.save(hdf5)
-        del hdf5
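
The validation block above writes a hand-built summary protobuf instead of evaluating a merged summary op. One likely reason: `tf.merge_all_summaries()` collects every summary in the graph, training ones included, so running it against the validation feed would also pull in ops tied to the prefetch queue. A self-contained sketch of the manual pattern:

import tensorflow as tf
from tensorflow.core.framework import summary_pb2

writer = tf.train.SummaryWriter('cnn/validation')

# Any plain Python float can be logged this way, no summary op required
loss_value, step = 0.42, 100
value = summary_pb2.Summary.Value(tag="loss", simple_value=float(loss_value))
writer.add_summary(summary_pb2.Summary(value=[value]), step)
writer.close()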