Commit b2d5c736 authored by Tiago Pereira

Set up training from a saved model file

parent 47480241
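
This commit makes `Trainer.create_network_from_file` usable end to end: the data/label placeholders, the graph and the `global_step` counter are now registered in TensorFlow collections when the network is first built, and recovered via `tf.train.import_meta_graph` when training resumes from disk. As a rough guide to the flow exercised by the test below, here is a minimal sketch assuming TF 1.x; the data fixtures (`train_data`, `iterations`, the batch size) are hypothetical stand-ins for the test's MNIST setup:

import os
import numpy
import tensorflow as tf
from bob.learn.tensorflow.datashuffler import Memory, ScaleFactor
from bob.learn.tensorflow.trainers import Trainer, constant
from bob.learn.tensorflow.loss import BaseLoss

# Hypothetical fixtures standing in for the test's MNIST setup.
train_data = numpy.random.rand(100, 28, 28, 1).astype("float32")
train_labels = numpy.random.randint(0, 10, size=100)
directory = "./temp/cnn"
iterations = 500

train_data_shuffler = Memory(train_data, train_labels,
                             input_shape=[None, 28, 28, 1],
                             batch_size=16,
                             normalizer=ScaleFactor())

# Phase 1: build the network from scratch; the trainer checkpoints into temp_dir.
loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean, name="loss")
graph = scratch_network(train_data_shuffler("data"))  # network defined in the test below
trainer = Trainer(train_data_shuffler, iterations=iterations,
                  analizer=None, temp_dir=directory)
trainer.create_network_from_scratch(graph=graph, loss=loss,
                                    learning_rate=constant(0.1, name="regular_lr"),
                                    optimizer=tf.train.GradientDescentOptimizer(0.1))
trainer.train()

# Phase 2: rebuild placeholders, graph and global_step from the saved meta file.
trainer = Trainer(train_data_shuffler, iterations=iterations * 3,
                  analizer=None, temp_dir=directory)
trainer.create_network_from_file(os.path.join(directory, "model.ckp"))
trainer.train()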
@@ -6,7 +6,7 @@
 import numpy
 import bob.io.base
 import os
-from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation, TripletMemory, SiameseMemory
+from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation, TripletMemory, SiameseMemory, ScaleFactor
 from bob.learn.tensorflow.loss import BaseLoss, TripletLoss, ContrastiveLoss
 from bob.learn.tensorflow.trainers import Trainer, constant, TripletTrainer, SiameseTrainer
 from bob.learn.tensorflow.utils import load_mnist
@@ -33,17 +33,19 @@ def scratch_network(input_pl):
     # Creating a random network
     slim = tf.contrib.slim

-    initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=10)
-    scratch = slim.conv2d(input_pl, 10, 3, activation_fn=tf.nn.tanh,
-                          stride=1,
-                          weights_initializer=initializer,
-                          scope='conv1')
-    scratch = slim.flatten(scratch, scope='flatten1')
-    scratch = slim.fully_connected(scratch, 10,
-                                   weights_initializer=initializer,
-                                   activation_fn=None,
-                                   scope='fc1')
+    with tf.device("/cpu:0"):
+        initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=10)
+        scratch = slim.conv2d(input_pl, 16, [3, 3], activation_fn=tf.nn.relu,
+                              stride=1,
+                              weights_initializer=initializer,
+                              scope='conv1')
+        scratch = slim.max_pool2d(scratch, kernel_size=[2, 2], scope='pool1')
+        scratch = slim.flatten(scratch, scope='flatten1')
+        scratch = slim.fully_connected(scratch, 10,
+                                       weights_initializer=initializer,
+                                       activation_fn=None,
+                                       scope='fc1')

     return scratch
@@ -58,7 +60,8 @@ def test_cnn_pretrained():
     train_data_shuffler = Memory(train_data, train_labels,
                                  input_shape=[None, 28, 28, 1],
                                  batch_size=batch_size,
-                                 data_augmentation=data_augmentation)
+                                 data_augmentation=data_augmentation,
+                                 normalizer=ScaleFactor())

     validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
     directory = "./temp/cnn"
@@ -81,39 +84,35 @@
                       )

     trainer.create_network_from_scratch(graph=graph,
                                         loss=loss,
-                                        learning_rate=constant(0.01, name="regular_lr"),
-                                        optimizer=tf.train.GradientDescentOptimizer(0.01),
+                                        learning_rate=constant(0.1, name="regular_lr"),
+                                        optimizer=tf.train.GradientDescentOptimizer(0.1),
                                         )
     trainer.train()
     accuracy = validate_network(embedding, validation_data, validation_labels)
     assert accuracy > 80

     tf.reset_default_graph()
     del graph
     del loss
     del trainer
     del embedding

     # Training the network using a pre-trained model
     loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean, name="loss")
     graph = scratch_network(input_pl)

     # One graph trainer
     trainer = Trainer(train_data_shuffler,
-                      iterations=iterations,
+                      iterations=iterations*3,
                       analizer=None,
                       temp_dir=directory
                       )
     trainer.create_network_from_file(os.path.join(directory, "model.ckp"))
-    import ipdb;
-    ipdb.set_trace()
     trainer.train()

     embedding = Embedding(trainer.data_ph, trainer.graph)
     accuracy = validate_network(embedding, validation_data, validation_labels)
     assert accuracy > 90

     shutil.rmtree(directory)
     shutil.rmtree(directory2)

     del graph
     del loss
     del trainer
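
The pre-trained branch of this test leans on the `Trainer` changes below: the tensors needed to resume (data/label placeholders, the graph, `global_step`) are stashed in named collections when the network is first created, and `create_network_from_file` retrieves them after importing the meta graph. A minimal sketch of that round trip in plain TF 1.x, assuming a checkpoint was previously saved under the `./temp/cnn/model.ckp` prefix:

import tensorflow as tf

# Build time: register the placeholder under a well-known collection key.
data_ph = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name="data_ph")
tf.add_to_collection("data_ph", data_ph)

# Restore time: import the meta graph, then fetch the tensor back.
# Collections are lists, hence the trailing [0] (the bug this commit fixes).
saver = tf.train.import_meta_graph("./temp/cnn/model.ckp.meta")
restored_data_ph = tf.get_collection("data_ph")[0]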
@@ -131,8 +131,6 @@ class Trainer(object):
                                     learning_rate=None,
                                     ):

-        self.saver = tf.train.Saver(var_list=tf.global_variables())
-
         self.data_ph = self.train_data_shuffler("data")
         self.label_ph = self.train_data_shuffler("label")
         self.graph = graph
@@ -144,6 +142,10 @@
         # TODO: find an elegant way to provide this as a parameter of the trainer
         self.global_step = tf.Variable(0, trainable=False, name="global_step")

+        # Saving all the variables
+        self.saver = tf.train.Saver(var_list=tf.global_variables())
+        tf.add_to_collection("global_step", self.global_step)
+        tf.add_to_collection("graph", self.graph)
@@ -161,6 +163,7 @@
         self.summaries_train = self.create_general_summary()
         tf.add_to_collection("summaries_train", self.summaries_train)

+        # Creating the variables
         tf.global_variables_initializer().run(session=self.session)
@@ -173,15 +176,14 @@
         train_data_shuffler: Data shuffler for training
         validation_data_shuffler: Data shuffler for validation

         """
-        #saver = self.architecture.load(self.model_from_file, clear_devices=False)
         self.saver = tf.train.import_meta_graph(model_from_file + ".meta")
         self.saver.restore(self.session, model_from_file)

         # Loading training graph
-        self.data_ph = tf.get_collection("data_ph")
-        self.label_ph = tf.get_collection("label_ph")
+        self.data_ph = tf.get_collection("data_ph")[0]
+        self.label_ph = tf.get_collection("label_ph")[0]

         self.graph = tf.get_collection("graph")[0]
         self.predictor = tf.get_collection("predictor")[0]
@@ -194,10 +196,7 @@
         self.from_scratch = False

         # Creating the variables
-        import ipdb; ipdb.set_trace()
-        x=0
-        #tf.global_variables_initializer().run(session=self.session)
+        tf.global_variables_initializer().run(session=self.session)

     def __del__(self):
         tf.reset_default_graph()
@@ -356,7 +355,7 @@
             if step % self.snapshot == 0:
                 logger.info("Taking snapshot")
                 path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
-                self.saver.save(self.session, path)
+                self.saver.save(self.session, path, global_step=step)
                 #self.architecture.save(saver, path)

         logger.info("Training finally finished")
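
One detail on the snapshot fix above: when `global_step` is passed, `tf.train.Saver.save` appends the step number to the checkpoint prefix, so successive snapshots get distinct names instead of overwriting a single file. A sketch (with `saver`, `session` and `path` as in the training loop above):

# With global_step, TF 1.x writes step-suffixed checkpoint files,
# e.g. model_snapshot500.ckp-500 plus its .meta/.index/.data companions.
saver.save(session, path, global_step=500)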