Created mechanism to amend networks

parent 422d8e02
Pipeline #12947 failed in 6 minutes and 21 seconds
@@ -23,7 +23,7 @@ step2_path = os.path.join(directory, "step2")
 slim = tf.contrib.slim

-def base_network(train_data_shuffler, reuse=False, get_embedding=False):
+def base_network(train_data_shuffler, reuse=False, get_embedding=False, trainable=True):
     if isinstance(train_data_shuffler, tf.Tensor):
         inputs = train_data_shuffler
@@ -33,11 +33,11 @@ def base_network(train_data_shuffler, reuse=False, get_embedding=False):
     # Creating a random network
     initializer = tf.contrib.layers.xavier_initializer(seed=seed)
     graph = slim.conv2d(inputs, 10, [3, 3], activation_fn=tf.nn.relu, stride=1, scope='conv1',
-                        weights_initializer=initializer, reuse=reuse)
+                        weights_initializer=initializer, reuse=reuse, trainable=trainable)
     graph = slim.max_pool2d(graph, [4, 4], scope='pool1')
     graph = slim.flatten(graph, scope='flatten1')
     graph = slim.fully_connected(graph, 30, activation_fn=None, scope='fc1',
-                                 weights_initializer=initializer, reuse=reuse)
+                                 weights_initializer=initializer, reuse=reuse, trainable=trainable)

     if get_embedding:
         graph = graph
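
The new trainable flag is simply forwarded to the slim layers, which decide whether the variables they create join tf.GraphKeys.TRAINABLE_VARIABLES. A minimal sketch of the effect (input shape and layer sizes are illustrative, not taken from the test):

import tensorflow as tf
slim = tf.contrib.slim

inputs = tf.placeholder(tf.float32, [None, 28, 28, 1])
# conv1 is frozen: its kernel and bias are still created, but they are
# excluded from tf.GraphKeys.TRAINABLE_VARIABLES, so optimizers skip them.
graph = slim.conv2d(inputs, 10, [3, 3], scope='conv1', trainable=False)
graph = slim.fully_connected(slim.flatten(graph), 30, scope='fc1', trainable=True)

assert len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='conv1')) == 0
assert len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='fc1')) == 2  # weights + biases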
@@ -95,17 +95,11 @@ def test_trainable_variables():
     writer.close()

-    # 1 - Create
-    # 2 - Initialize
-    # 3 - Minimize with certain variables
-    # 4 - Load the last checkpoint

-    ######## 1 - BASE NETWORK #########
+    ######## BASE NETWORK #########
     tfrecords_filename = "mnist_train.tfrecords"
-    #create_tf_record(tfrecords_filename, train_data, train_labels)
+    create_tf_record(tfrecords_filename, train_data, train_labels)
     filename_queue = tf.train.string_input_producer([tfrecords_filename], num_epochs=1, name="input")

     # Doing the first training
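
The create_tf_record helper that is re-enabled above is defined outside this hunk; a plausible sketch of such a writer, assuming numpy image arrays and integer labels (the feature keys 'data' and 'label' are assumptions, not the test's actual schema):

import tensorflow as tf

def create_tf_record(tfrecords_filename, data, labels):
    # Serialize each (image, label) pair into one tf.train.Example record
    writer = tf.python_io.TFRecordWriter(tfrecords_filename)
    for img, label in zip(data, labels):
        feature = {
            'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tostring()])),
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
        }
        writer.write(tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString())
    writer.close()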
@@ -128,21 +122,22 @@ def test_trainable_variables():
                                         )
     trainer.train()

-    conv1_trained = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1')[0].eval(session=trainer.session)[0]
+    # Saving the conv1 weights after the first training
+    conv1_after_first_train = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1')[0].eval(session=trainer.session)[0]

     del trainer
     del filename_queue
     del train_data_shuffler
     tf.reset_default_graph()

-    ##### Creating an amendment network
+    ######## 2 - AMENDING NETWORK #########
     filename_queue = tf.train.string_input_producer([tfrecords_filename], num_epochs=1, name="input")
     train_data_shuffler = TFRecord(filename_queue=filename_queue,
                                    batch_size=batch_size)

-    graph = base_network(train_data_shuffler, get_embedding=True)
+    # Here I'm creating the base network as non-trainable
+    graph = base_network(train_data_shuffler, get_embedding=True, trainable=False)
     graph = amendment_network(graph)
     loss = MeanSoftMaxLoss(add_regularization_losses=False)
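
amendment_network itself is not part of this diff; conceptually it stacks fresh, trainable layers on top of the frozen 30-dimensional embedding returned by base_network. A hypothetical sketch (layer sizes and scope names are assumptions):

def amendment_network(graph, reuse=False):
    # New trainable layers appended after the frozen base network
    slim = tf.contrib.slim
    initializer = tf.contrib.layers.xavier_initializer()
    graph = slim.fully_connected(graph, 20, activation_fn=tf.nn.relu, scope='fc2',
                                 weights_initializer=initializer, reuse=reuse)
    graph = slim.fully_connected(graph, 10, activation_fn=None, scope='logits',
                                 weights_initializer=initializer, reuse=reuse)
    return graph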
@@ -151,7 +146,6 @@ def test_trainable_variables():
                       analizer=None,
                       temp_dir=step2_path)

     learning_rate = constant(0.01, name="regular_lr")
     trainer.create_network_from_scratch(graph=graph,
                                         loss=loss,
@@ -159,45 +153,18 @@ def test_trainable_variables():
                                         optimizer=tf.train.GradientDescentOptimizer(learning_rate),
                                         )

-    conv1_before_load = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1')[0].eval(session=trainer.session)[0]
-    var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1') + \
-               tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='fc1')
-    saver = tf.train.Saver(var_list)
-    saver.restore(trainer.session, os.path.join(step1_path, "model.ckp"))
+    # Loading two layers from the "old" model
+    external_model = os.path.join(step1_path, "model.ckp")
+    trainer.load_variables_from_external_model(external_model, var_list=['conv1', 'fc1'])
+    conv1_restored = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1')[0].eval(session=trainer.session)[0]
+    assert numpy.allclose(conv1_after_first_train, conv1_restored)

     trainer.train()

-    conv1_after_train = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1')[0].eval(session=trainer.session)[0]
-    print(conv1_trained - conv1_before_load)
-    print(conv1_trained - conv1_restored)
-    print(conv1_trained - conv1_after_train)
-    import ipdb; ipdb.set_trace();
-    x = 0
+    # Second round of training
+    trainer.train()
+    conv1_after_second_train = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='conv1')[0].eval(session=trainer.session)[0]

-    #var_list = tf.get_collection(tf.GraphKeys.VARIABLES, scope='fc1') + tf.get_collection(tf.GraphKeys.VARIABLES, scope='logits')
-    #optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss, global_step=global_step, var_list=var_list)
-    #print("Go ...")
-    """
-    last_iteration = numpy.sum(tf.trainable_variables()[0].eval(session=session)[0])
-    for i in range(10):
-        _, l = session.run([optimizer, loss])
-        current_iteration = numpy.sum(tf.trainable_variables()[0].eval(session=session)[0])
-        print numpy.abs(current_iteration - last_iteration)
-        current_iteration = last_iteration
-        print l
-    thread_pool.request_stop()
-    """
-    #x = 0

+    # Since conv1 was set as NON TRAINABLE, both have to match
+    assert numpy.allclose(conv1_after_first_train, conv1_after_second_train)
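
The final assert holds because Optimizer.minimize() defaults its var_list to tf.trainable_variables(), so a variable created with trainable=False is never updated. A self-contained illustration of the same invariant:

import numpy
import tensorflow as tf

frozen = tf.get_variable('frozen_w', initializer=tf.ones([2, 2]), trainable=False)
free = tf.get_variable('free_w', initializer=tf.ones([2, 2]))
loss = tf.reduce_sum(tf.square(frozen + free))
# minimize() only touches tf.GraphKeys.TRAINABLE_VARIABLES, i.e. only free
step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    before = session.run(frozen)
    session.run(step)
    assert numpy.allclose(before, session.run(frozen))  # frozen stays put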
@@ -317,6 +317,28 @@ class Trainer(object):
         else:
             self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
             self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))

+    def load_variables_from_external_model(self, file_name, var_list):
+        """
+        Load a set of variables from a given model and update them in the current session.
+
+        **Parameters**
+
+        file_name:
+            Name of the TensorFlow model to be loaded
+
+        var_list:
+            List of variable scopes to be loaded. A TensorFlow exception is raised if a variable does not exist.
+        """
+        assert len(var_list) > 0
+
+        tf_varlist = []
+        for v in var_list:
+            tf_varlist += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=v)
+
+        saver = tf.train.Saver(tf_varlist)
+        saver.restore(self.session, file_name)
     def create_network_from_file(self, file_name, clear_devices=True):
         """
         ...
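
Typical use of the new helper, mirroring the test above: build the new graph first (so the 'conv1' and 'fc1' scopes exist), then pull just those scopes from the checkpoint written in step 1.

external_model = os.path.join(step1_path, "model.ckp")
# Restore only the 'conv1' and 'fc1' variable scopes into the live session
trainer.load_variables_from_external_model(external_model, var_list=['conv1', 'fc1'])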