Commit 7205709b authored by Tiago de Freitas Pereira

Transfer learning for Siamese nets

parent e1f265e6
Pipeline #13497 passed with stages
in 22 minutes and 40 seconds
......@@ -82,6 +82,14 @@ class Logits(estimator.Estimator):
params:
Extra params for the model function (please see https://www.tensorflow.org/extend/estimators for more info)
extra_checkpoint: dict()
In case you want to use another model to initialize some variables.
This argument should be in the following format:
extra_checkpoint = {"checkpoint_path": <YOUR_CHECKPOINT>,
"scopes": dict({"<SOURCE_SCOPE>/": "<TARGET_SCOPE>/"}),
"is_trainable": <ARE_THE_LOADED_VARIABLES_TRAINABLE>
}
"""
def __init__(self,
......
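For reference, this is how the documented extra_checkpoint format could be used when building a Logits estimator. The checkpoint path and the import paths below are illustrative assumptions, not part of this commit:

import tensorflow as tf
from bob.learn.tensorflow.estimators import Logits  # assumed import path
from bob.learn.tensorflow.network import dummy  # assumed import path
from bob.learn.tensorflow.loss import mean_cross_entropy_loss

# Initialize every variable under "Dummy/" from a previously trained model
# and keep the loaded variables frozen during the new training run.
extra_checkpoint = {"checkpoint_path": "./pretrained_model",  # hypothetical path
                    "scopes": dict({"Dummy/": "Dummy/"}),
                    "is_trainable": False}

trainer = Logits(model_dir="./logits_model",
                 architecture=dummy,
                 optimizer=tf.train.GradientDescentOptimizer(0.001),
                 n_classes=10,
                 loss_op=mean_cross_entropy_loss,
                 extra_checkpoint=extra_checkpoint)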
......@@ -19,6 +19,10 @@ import logging
logger = logging.getLogger("bob.learn")
from bob.learn.tensorflow.network.utils import append_logits
from bob.learn.tensorflow.loss import mean_cross_entropy_loss
class Siamese(estimator.Estimator):
"""
NN estimator for Siamese networks
......@@ -39,6 +43,14 @@ class Siamese(estimator.Estimator):
return loss_set_of_ops(logits, labels)
extra_checkpoint = {"checkpoint_path":model_dir,
"scopes": dict({"Dummy/": "Dummy/"}),
"is_trainable": False
}
**Parameters**
architecture:
Pointer to a function that builds the graph.
......@@ -51,9 +63,6 @@ class Siamese(estimator.Estimator):
config:
loss_op:
Pointer to a function that computes the loss.
......@@ -66,13 +75,26 @@ class Siamese(estimator.Estimator):
validation_batch_size:
Size of the batch used for validation. This value is used when validating
with embeddings. This is a hack.
params:
Extra params for the model function
(please see https://www.tensorflow.org/extend/estimators for more info)
extra_checkpoint: dict()
In case you want to use another model to initialize some variables.
This argument should be in the following format:
extra_checkpoint = {"checkpoint_path": <YOUR_CHECKPOINT>,
"scopes": dict({"<SOURCE_SCOPE>/": "<TARGET_SCOPE>/"}),
"is_trainable": <ARE_THE_LOADED_VARIABLES_TRAINABLE>
}
"""
def __init__(self,
architecture=None,
optimizer=None,
config=None,
loss_op=None,
model_dir="",
validation_batch_size=None,
......@@ -82,7 +104,6 @@ class Siamese(estimator.Estimator):
self.architecture = architecture
self.optimizer=optimizer
self.loss_op=loss_op
self.loss = None
self.extra_checkpoint = extra_checkpoint
......@@ -96,9 +117,6 @@ class Siamese(estimator.Estimator):
if self.loss_op is None:
raise ValueError("Please specify a function to build the loss !!")
def _model_fn(features, labels, mode, params, config):
if mode == tf.estimator.ModeKeys.TRAIN:
......@@ -112,22 +130,17 @@ class Siamese(estimator.Estimator):
if 'left' not in features or 'right' not in features:
raise ValueError("The input function needs to contain a dictionary with the keys `left` and `right`")
# Building one graph
prelogits_left = self.architecture(features['left'], is_trainable=is_trainable)[0]
prelogits_right = self.architecture(features['right'], reuse=True, is_trainable=is_trainable)[0]
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
# Compute Loss (for both TRAIN and EVAL modes)
self.loss = self.loss_op(prelogits_left, prelogits_right, labels)
# Configure the Training Op (for TRAIN mode)
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = self.optimizer.minimize(self.loss, global_step=global_step)
......
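The transfer above relies on tf.contrib.framework.init_from_checkpoint, which rewires variable initializers so that variables in the current graph take their initial values from the matching entries of a checkpoint. A standalone sketch of the scope mapping, with a hypothetical checkpoint path (the "Dummy" layout mirrors the toy network used in the tests below):

import tensorflow as tf
slim = tf.contrib.slim

def build_and_transfer(inputs):
    # Rebuild a graph with the same variable layout as the source checkpoint
    with tf.variable_scope("Dummy"):
        net = slim.conv2d(inputs, 10, [3, 3], scope="conv1")
    # Map the "Dummy/" scope of the checkpoint onto the "Dummy/" scope of the
    # current graph. A different target key (e.g. dict({"Dummy/": "Dummy1/"}))
    # would load the same weights into a renamed scope of an adapted architecture.
    tf.contrib.framework.init_from_checkpoint("./pretrained_model",  # hypothetical
                                              dict({"Dummy/": "Dummy/"}))
    return net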
......@@ -19,6 +19,7 @@ def dummy(inputs, reuse=False, is_trainable=True):
end_points = dict()
with tf.variable_scope('Dummy', reuse=reuse):
initializer = tf.contrib.layers.xavier_initializer()
graph = slim.conv2d(inputs, 10, [3, 3], activation_fn=tf.nn.relu, stride=1, scope='conv1',
......
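The hunk above is truncated, so the full body of dummy is not visible; a plausible sketch of how such an architecture function threads is_trainable through to its slim layers (the flatten layer and the exact argument list are assumptions):

import tensorflow as tf
slim = tf.contrib.slim

def dummy(inputs, reuse=False, is_trainable=True):
    # Toy architecture returning (prelogits, end_points)
    end_points = dict()
    with tf.variable_scope('Dummy', reuse=reuse):
        initializer = tf.contrib.layers.xavier_initializer()
        graph = slim.conv2d(inputs, 10, [3, 3], activation_fn=tf.nn.relu,
                            stride=1, scope='conv1',
                            weights_initializer=initializer,
                            trainable=is_trainable)  # frozen when transferring
        end_points['conv1'] = graph
        graph = slim.flatten(graph, scope='flatten1')
        end_points['flatten1'] = graph
    return graph, end_points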
......@@ -27,7 +27,7 @@ tfrecord_validation = "./validation_mnist.tfrecord"
model_dir = "./temp"
model_dir_adapted = "./temp2"
learning_rate = 0.0001
data_shape = (250, 250, 3)  # size of the input images
output_shape = (50, 50)
data_type = tf.float32
......@@ -69,7 +69,6 @@ def test_siamesetrainer():
trainer = Siamese(model_dir=model_dir,
architecture=dummy,
optimizer=tf.train.GradientDescentOptimizer(learning_rate),
loss_op=contrastive_loss,
validation_batch_size=validation_batch_size)
run_siamesetrainer(trainer)
......@@ -84,55 +83,38 @@ def test_siamesetrainer():
def test_siamesetrainer_transfer():
def logits_input_fn():
return single_batch(filenames, labels, data_shape, data_type, batch_size, epochs=epochs, output_shape=output_shape)
# Train the logits model first, then the siamese
try:
extra_checkpoint = {"checkpoint_path":model_dir,
"scopes": [dict({"Dummy/": "Dummy/"}),
dict({"Dummy/": "Dummy1/"})],
"scopes": dict({"Dummy/": "Dummy/"}),
"is_trainable": False
}
# LOGITS
logits_trainer = Logits(model_dir=model_dir,
architecture=dummy,
optimizer=tf.train.GradientDescentOptimizer(learning_rate),
n_classes=2,
loss_op=mean_cross_entropy_loss,
embedding_validation=False,
validation_batch_size=validation_batch_size)
logits_trainer.train(logits_input_fn, steps=steps)
# Now the siamese, initialized from the logits model
trainer = Siamese(model_dir=model_dir_adapted,
architecture=dummy_adapted,
optimizer=tf.train.GradientDescentOptimizer(learning_rate),
loss_op=contrastive_loss,
validation_batch_size=validation_batch_size,
extra_checkpoint=extra_checkpoint)
run_siamesetrainer(trainer)
finally:
try:
shutil.rmtree(model_dir, ignore_errors=True)
shutil.rmtree(model_dir_adapted, ignore_errors=True)
except Exception:
pass
......@@ -157,7 +139,7 @@ def run_siamesetrainer(trainer):
scaffold=tf.train.Scaffold(),
summary_writer=tf.summary.FileWriter(model_dir) )]
trainer.train(input_fn, steps=steps, hooks=hooks)
acc = trainer.evaluate(input_validation_fn)
assert acc['accuracy'] > 0.5
......
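run_siamesetrainer is only partially shown here; as the Siamese model function requires, its input_fn must yield a dict with the keys left and right. A hypothetical minimal sketch of such a pair-based input function (the in-memory random data and the label convention are assumptions; the real tests read tfrecords):

import numpy as np
import tensorflow as tf

def siamese_input_fn():
    # Hypothetical in-memory pairs of 50x50 RGB images
    left = np.random.rand(32, 50, 50, 3).astype('float32')
    right = np.random.rand(32, 50, 50, 3).astype('float32')
    # Assumed convention: 0 for genuine pairs, 1 for impostor pairs
    labels = np.random.randint(0, 2, size=32).astype('int64')
    dataset = tf.data.Dataset.from_tensor_slices(({'left': left, 'right': right}, labels))
    dataset = dataset.batch(8)
    return dataset.make_one_shot_iterator().get_next()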