Commit b2d5c736 authored by Tiago Pereira

Set the training from file

parent 47480241
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
import numpy import numpy
import bob.io.base import bob.io.base
import os import os
from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation, TripletMemory, SiameseMemory from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation, TripletMemory, SiameseMemory, ScaleFactor
from bob.learn.tensorflow.loss import BaseLoss, TripletLoss, ContrastiveLoss from bob.learn.tensorflow.loss import BaseLoss, TripletLoss, ContrastiveLoss
from bob.learn.tensorflow.trainers import Trainer, constant, TripletTrainer, SiameseTrainer from bob.learn.tensorflow.trainers import Trainer, constant, TripletTrainer, SiameseTrainer
from bob.learn.tensorflow.utils import load_mnist from bob.learn.tensorflow.utils import load_mnist
...@@ -33,17 +33,19 @@ def scratch_network(input_pl): ...@@ -33,17 +33,19 @@ def scratch_network(input_pl):
# Creating a random network # Creating a random network
slim = tf.contrib.slim slim = tf.contrib.slim
initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=10) with tf.device("/cpu:0"):
initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=10)
scratch = slim.conv2d(input_pl, 10, 3, activation_fn=tf.nn.tanh, scratch = slim.conv2d(input_pl, 16, [3, 3], activation_fn=tf.nn.relu,
stride=1, stride=1,
weights_initializer=initializer, weights_initializer=initializer,
scope='conv1') scope='conv1')
scratch = slim.flatten(scratch, scope='flatten1') scratch = slim.max_pool2d(scratch, kernel_size=[2, 2], scope='pool1')
scratch = slim.fully_connected(scratch, 10, scratch = slim.flatten(scratch, scope='flatten1')
weights_initializer=initializer, scratch = slim.fully_connected(scratch, 10,
activation_fn=None, weights_initializer=initializer,
scope='fc1') activation_fn=None,
scope='fc1')
return scratch return scratch
...@@ -58,7 +60,8 @@ def test_cnn_pretrained(): ...@@ -58,7 +60,8 @@ def test_cnn_pretrained():
train_data_shuffler = Memory(train_data, train_labels, train_data_shuffler = Memory(train_data, train_labels,
input_shape=[None, 28, 28, 1], input_shape=[None, 28, 28, 1],
batch_size=batch_size, batch_size=batch_size,
data_augmentation=data_augmentation) data_augmentation=data_augmentation,
normalizer=ScaleFactor())
validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1)) validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
directory = "./temp/cnn" directory = "./temp/cnn"
...@@ -81,39 +84,35 @@ def test_cnn_pretrained(): ...@@ -81,39 +84,35 @@ def test_cnn_pretrained():
) )
trainer.create_network_from_scratch(graph=graph, trainer.create_network_from_scratch(graph=graph,
loss=loss, loss=loss,
learning_rate=constant(0.01, name="regular_lr"), learning_rate=constant(0.1, name="regular_lr"),
optimizer=tf.train.GradientDescentOptimizer(0.01), optimizer=tf.train.GradientDescentOptimizer(0.1),
) )
trainer.train() trainer.train()
accuracy = validate_network(embedding, validation_data, validation_labels) accuracy = validate_network(embedding, validation_data, validation_labels)
assert accuracy > 80 assert accuracy > 80
tf.reset_default_graph()
del graph del graph
del loss del loss
del trainer del trainer
del embedding
# Training the network using a pre trained model # Training the network using a pre trained model
loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean, name="loss") loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean, name="loss")
graph = scratch_network(input_pl)
# One graph trainer # One graph trainer
trainer = Trainer(train_data_shuffler, trainer = Trainer(train_data_shuffler,
iterations=iterations, iterations=iterations*3,
analizer=None, analizer=None,
temp_dir=directory temp_dir=directory
) )
trainer.create_network_from_file(os.path.join(directory, "model.ckp")) trainer.create_network_from_file(os.path.join(directory, "model.ckp"))
import ipdb;
ipdb.set_trace()
trainer.train() trainer.train()
embedding = Embedding(trainer.data_ph, trainer.graph)
accuracy = validate_network(embedding, validation_data, validation_labels) accuracy = validate_network(embedding, validation_data, validation_labels)
assert accuracy > 90 assert accuracy > 90
shutil.rmtree(directory) shutil.rmtree(directory)
shutil.rmtree(directory2)
del graph
del loss del loss
del trainer del trainer
......
...@@ -131,8 +131,6 @@ class Trainer(object): ...@@ -131,8 +131,6 @@ class Trainer(object):
learning_rate=None, learning_rate=None,
): ):
self.saver = tf.train.Saver(var_list=tf.global_variables())
self.data_ph = self.train_data_shuffler("data") self.data_ph = self.train_data_shuffler("data")
self.label_ph = self.train_data_shuffler("label") self.label_ph = self.train_data_shuffler("label")
self.graph = graph self.graph = graph
...@@ -144,6 +142,10 @@ class Trainer(object): ...@@ -144,6 +142,10 @@ class Trainer(object):
# TODO: find an elegant way to provide this as a parameter of the trainer # TODO: find an elegant way to provide this as a parameter of the trainer
self.global_step = tf.Variable(0, trainable=False, name="global_step") self.global_step = tf.Variable(0, trainable=False, name="global_step")
# Saving all the variables
self.saver = tf.train.Saver(var_list=tf.global_variables())
tf.add_to_collection("global_step", self.global_step) tf.add_to_collection("global_step", self.global_step)
tf.add_to_collection("graph", self.graph) tf.add_to_collection("graph", self.graph)
...@@ -161,6 +163,7 @@ class Trainer(object): ...@@ -161,6 +163,7 @@ class Trainer(object):
self.summaries_train = self.create_general_summary() self.summaries_train = self.create_general_summary()
tf.add_to_collection("summaries_train", self.summaries_train) tf.add_to_collection("summaries_train", self.summaries_train)
# Creating the variables # Creating the variables
tf.global_variables_initializer().run(session=self.session) tf.global_variables_initializer().run(session=self.session)
...@@ -173,15 +176,14 @@ class Trainer(object): ...@@ -173,15 +176,14 @@ class Trainer(object):
train_data_shuffler: Data shuffler for training train_data_shuffler: Data shuffler for training
validation_data_shuffler: Data shuffler for validation validation_data_shuffler: Data shuffler for validation
""" """
#saver = self.architecture.load(self.model_from_file, clear_devices=False) #saver = self.architecture.load(self.model_from_file, clear_devices=False)
self.saver = tf.train.import_meta_graph(model_from_file + ".meta") self.saver = tf.train.import_meta_graph(model_from_file + ".meta")
self.saver.restore(self.session, model_from_file) self.saver.restore(self.session, model_from_file)
# Loading training graph # Loading training graph
self.data_ph = tf.get_collection("data_ph") self.data_ph = tf.get_collection("data_ph")[0]
self.label_ph = tf.get_collection("label_ph") self.label_ph = tf.get_collection("label_ph")[0]
self.graph = tf.get_collection("graph")[0] self.graph = tf.get_collection("graph")[0]
self.predictor = tf.get_collection("predictor")[0] self.predictor = tf.get_collection("predictor")[0]
...@@ -194,10 +196,7 @@ class Trainer(object): ...@@ -194,10 +196,7 @@ class Trainer(object):
self.from_scratch = False self.from_scratch = False
# Creating the variables # Creating the variables
tf.global_variables_initializer().run(session=self.session) #tf.global_variables_initializer().run(session=self.session)
import ipdb; ipdb.set_trace()
x=0
def __del__(self): def __del__(self):
tf.reset_default_graph() tf.reset_default_graph()
...@@ -356,7 +355,7 @@ class Trainer(object): ...@@ -356,7 +355,7 @@ class Trainer(object):
if step % self.snapshot == 0: if step % self.snapshot == 0:
logger.info("Taking snapshot") logger.info("Taking snapshot")
path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step)) path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
self.saver.save(self.session, path) self.saver.save(self.session, path, global_step=step)
#self.architecture.save(saver, path) #self.architecture.save(saver, path)
logger.info("Training finally finished") logger.info("Training finally finished")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment