Commit e1f265e6 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Implemented transfer mechanism for Siamese nets [debuging]

parent 7fb73d81
Pipeline #13478 failed with stages
in 19 minutes and 55 seconds
......@@ -13,7 +13,7 @@ import time
from tensorflow.python.estimator import estimator
from bob.learn.tensorflow.utils import predict_using_tensors
#from bob.learn.tensorflow.loss import mean_cross_entropy_center_loss
from . import check_features
from . import check_features, is_trainable_checkpoint
import logging
logger = logging.getLogger("bob.learn")
......@@ -76,6 +76,8 @@ class Siamese(estimator.Estimator):
loss_op=None,
model_dir="",
validation_batch_size=None,
params=None,
extra_checkpoint=None
):
self.architecture = architecture
......@@ -83,6 +85,7 @@ class Siamese(estimator.Estimator):
self.n_classes=n_classes
self.loss_op=loss_op
self.loss = None
self.extra_checkpoint = extra_checkpoint
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture !!")
......@@ -97,22 +100,38 @@ class Siamese(estimator.Estimator):
raise ValueError("Number of classes must be greated than 0")
def _model_fn(features, labels, mode, params, config):
if mode == tf.estimator.ModeKeys.TRAIN:
# Building one graph, by default everything is trainable
if self.extra_checkpoint is None:
is_trainable = True
else:
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
# The input function needs to have dictionary pair with the `left` and `right` keys
if not 'left' in features.keys() or not 'right' in features.keys():
raise ValueError("The input function needs to contain a dictionary with the keys `left` and `right` ")
# Building one graph
prelogits_left = self.architecture(features['left'])[0]
prelogits_right = self.architecture(features['right'], reuse=True)[0]
prelogits_left = self.architecture(features['left'], is_trainable=is_trainable)[0]
prelogits_right = self.architecture(features['right'], reuse=True, is_trainable=is_trainable)[0]
if self.extra_checkpoint is not None:
import ipdb; ipdb.set_trace();
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
dict({"Dummy/": "Dummy/"}))
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
dict({"Dummy/": "Dummy1/"}))
# Compute Loss (for both TRAIN and EVAL modes)
self.loss = self.loss_op(prelogits_left, prelogits_right, labels)
# Configure the Training Op (for TRAIN mode)
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = self.optimizer.minimize(self.loss, global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
......@@ -135,5 +154,6 @@ class Siamese(estimator.Estimator):
super(Siamese, self).__init__(model_fn=_model_fn,
model_dir=model_dir,
params=params,
config=config)
......@@ -2,12 +2,22 @@
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
import tensorflow as tf
def check_features(features):
if not 'data' in features.keys() or not 'key' in features.keys():
raise ValueError("The input function needs to contain a dictionary with the keys `data` and `key` ")
return True
def is_trainable_checkpoint(params):
if not "is_trainable" in params:
raise ValueError("Param `is_trainable` is missing in `load_variable_from_checkpoint` dictionary")
return params["is_trainable"]
from .Logits import Logits, LogitsCenterLoss
from .Siamese import Siamese
from .Triplet import Triplet
......
......@@ -5,13 +5,16 @@
import tensorflow as tf
from bob.learn.tensorflow.network import dummy
from bob.learn.tensorflow.estimators import Siamese
from bob.learn.tensorflow.estimators import Siamese, Logits
from bob.learn.tensorflow.dataset.siamese_image import shuffle_data_and_labels_image_augmentation as siamese_batch
from bob.learn.tensorflow.dataset.image import shuffle_data_and_labels_image_augmentation as single_batch
from bob.learn.tensorflow.loss import contrastive_loss
from bob.learn.tensorflow.loss import contrastive_loss, mean_cross_entropy_loss
from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator
from bob.learn.tensorflow.utils import reproducible
from .test_estimator_transfer import dummy_adapted
import pkg_resources
import numpy
......@@ -22,6 +25,7 @@ import os
tfrecord_train = "./train_mnist.tfrecord"
tfrecord_validation = "./validation_mnist.tfrecord"
model_dir = "./temp"
model_dir_adapted = "./temp2"
learning_rate = 0.001
data_shape = (250, 250, 3) # size of atnt images
......@@ -59,7 +63,7 @@ labels = [0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1]
def test_logitstrainer():
def test_siamesetrainer():
# Trainer logits
try:
trainer = Siamese(model_dir=model_dir,
......@@ -77,6 +81,62 @@ def test_logitstrainer():
pass
def test_siamesetrainer_transfer():
def logits_input_fn():
return single_batch(filenames, labels, data_shape, data_type, batch_size, epochs=epochs)
# Trainer logits first than siamese
try:
# LOGISTS
#logits_trainer = Logits(model_dir=model_dir,
# architecture=dummy,
# optimizer=tf.train.GradientDescentOptimizer(learning_rate),
# n_classes=10,
# loss_op=mean_cross_entropy_loss,
# embedding_validation=False,
# validation_batch_size=validation_batch_size)
#logits_trainer.train(logits_input_fn, steps=steps)
# Checking if the centers were updated
sess = tf.Session()
checkpoint_path = tf.train.get_checkpoint_state(model_dir).model_checkpoint_path
saver = tf.train.import_meta_graph(checkpoint_path + ".meta", clear_devices=True)
saver.restore(sess, tf.train.latest_checkpoint(model_dir))
conv1 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="Dummy/conv1/weights:0")[0]
print(conv1.eval(sess))
tf.reset_default_graph()
import ipdb; ipdb.set_trace();
extra_checkpoint = {"checkpoint_path":model_dir,
"scopes": [dict({"Dummy/": "Dummy/"}),
dict({"Dummy/": "Dummy1/"})],
"is_trainable": False
}
#del logits_trainer
# Checking if the centers were updated
# NOW THE FUCKING SIAMESE
trainer = Siamese(model_dir=model_dir_adapted,
architecture=dummy_adapted,
optimizer=tf.train.GradientDescentOptimizer(learning_rate),
n_classes=10,
loss_op=contrastive_loss,
validation_batch_size=validation_batch_size,
extra_checkpoint=extra_checkpoint)
#extra_checkpoint=None
run_siamesetrainer(trainer)
finally:
try:
#shutil.rmtree(model_dir, ignore_errors=True)
#shutil.rmtree(model_dir_adapted, ignore_errors=True)
pass
except Exception:
pass
def run_siamesetrainer(trainer):
# Cleaning up
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment