Skip to content
Snippets Groups Projects

Estimator that loads variables from another model

Merged Tiago de Freitas Pereira requested to merge issue-44 into master
All threads resolved!
3 files
+ 98
8
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -13,7 +13,7 @@ import time
from tensorflow.python.estimator import estimator
from bob.learn.tensorflow.utils import predict_using_tensors
#from bob.learn.tensorflow.loss import mean_cross_entropy_center_loss
from . import check_features
from . import check_features, is_trainable_checkpoint
import logging
logger = logging.getLogger("bob.learn")
@@ -76,6 +76,8 @@ class Siamese(estimator.Estimator):
loss_op=None,
model_dir="",
validation_batch_size=None,
params=None,
extra_checkpoint=None
):
self.architecture = architecture
@@ -83,6 +85,7 @@ class Siamese(estimator.Estimator):
self.n_classes=n_classes
self.loss_op=loss_op
self.loss = None
self.extra_checkpoint = extra_checkpoint
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture !!")
@@ -97,22 +100,38 @@ class Siamese(estimator.Estimator):
raise ValueError("Number of classes must be greated than 0")
def _model_fn(features, labels, mode, params, config):
if mode == tf.estimator.ModeKeys.TRAIN:
# Building one graph, by default everything is trainable
if self.extra_checkpoint is None:
is_trainable = True
else:
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
# The input function needs to have dictionary pair with the `left` and `right` keys
if not 'left' in features.keys() or not 'right' in features.keys():
raise ValueError("The input function needs to contain a dictionary with the keys `left` and `right` ")
# Building one graph
prelogits_left = self.architecture(features['left'])[0]
prelogits_right = self.architecture(features['right'], reuse=True)[0]
prelogits_left = self.architecture(features['left'], is_trainable=is_trainable)[0]
prelogits_right = self.architecture(features['right'], reuse=True, is_trainable=is_trainable)[0]
if self.extra_checkpoint is not None:
import ipdb; ipdb.set_trace();
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
dict({"Dummy/": "Dummy/"}))
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
dict({"Dummy/": "Dummy1/"}))
# Compute Loss (for both TRAIN and EVAL modes)
self.loss = self.loss_op(prelogits_left, prelogits_right, labels)
# Configure the Training Op (for TRAIN mode)
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = self.optimizer.minimize(self.loss, global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
@@ -135,5 +154,6 @@ class Siamese(estimator.Estimator):
super(Siamese, self).__init__(model_fn=_model_fn,
model_dir=model_dir,
params=params,
config=config)
Loading