
Estimator that loads variables from another model

Merged Tiago de Freitas Pereira requested to merge issue-44 into master
7 files changed, +330 −35
@@ -16,8 +16,8 @@ from bob.learn.tensorflow.network.utils import append_logits
from tensorflow.python.estimator import estimator
from bob.learn.tensorflow.utils import predict_using_tensors
from bob.learn.tensorflow.loss import mean_cross_entropy_center_loss
from . import check_features, is_trainable_checkpoint
import logging
logger = logging.getLogger("bob.learn")
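For context, `is_trainable_checkpoint` presumably just reads the "is_trainable" flag out of the extra_checkpoint dictionary used below; a rough sketch under that assumption, not necessarily the helper shipped in this package:

def is_trainable_checkpoint(extra_checkpoint):
    # Variables loaded from the other checkpoint stay trainable unless the
    # dictionary explicitly says otherwise
    return extra_checkpoint.get("is_trainable", True)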
@@ -43,6 +43,15 @@ class Logits(estimator.Estimator):
return loss_set_of_ops(logits, labels)
Variables and scopes from other models can be loaded by the model_fn.
For that, please wrap the path of the OTHER checkpoint and the mapping of
variable scopes in a dictionary and provide it via the keyword `extra_checkpoint` (see the usage sketch right after this docstring):
extra_checkpoint = {"checkpoint_path": "mypath",
"scopes": {"my_scope/": "my_scope/"}}
**Parameters**
architecture:
Pointer to a function that builds the graph.
@@ -70,6 +79,17 @@ class Logits(estimator.Estimator):
validation_batch_size:
Size of the batch for validation. This value is used when the
validation with embeddings is used. This is a hack.
params:
Extra params for the model function (please see https://www.tensorflow.org/extend/estimators for more info)
extra_checkpoint: dict()
In case you want to use another model to initialize some variables.
This argument should be in the following format:
extra_checkpoint = {"checkpoint_path": <YOUR_CHECKPOINT>,
"scopes": dict({"<SOURCE_SCOPE>/": "<TARGET_SCOPE>/"}),
"is_trainable": <IF_THOSE_LOADED_VARIABLES_ARE_TRAINABLE>}
"""
def __init__(self,
@@ -81,6 +101,8 @@ class Logits(estimator.Estimator):
embedding_validation=False,
model_dir="",
validation_batch_size=None,
params=None,
extra_checkpoint=None
):
self.architecture = architecture
@@ -89,6 +111,7 @@ class Logits(estimator.Estimator):
self.loss_op=loss_op
self.loss = None
self.embedding_validation = embedding_validation
self.extra_checkpoint = extra_checkpoint
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture !!")
@@ -107,9 +130,14 @@ class Logits(estimator.Estimator):
check_features(features)
data = features['data']
key = features['key']
# Building one graph, by default everything is trainable
if self.extra_checkpoint is None:
is_trainable = True
else:
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
prelogits = self.architecture(data, is_trainable=is_trainable)[0]
logits = append_logits(prelogits, n_classes)
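For context, `append_logits` presumably appends one final linear layer mapping the prelogits to one logit per class; a sketch under that assumption, not the actual bob.learn.tensorflow.network.utils implementation:

import tensorflow as tf

def sketch_append_logits(prelogits, n_classes):
    # Fully connected layer with no activation; the softmax is applied later,
    # by the loss op and by the "probabilities" prediction
    return tf.layers.dense(prelogits, n_classes, activation=None, name="logits")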
if self.embedding_validation:
@@ -118,7 +146,6 @@ class Logits(estimator.Estimator):
predictions = {
"embeddings": embeddings
}
else:
predictions = {
# Generate predictions (for PREDICT and EVAL mode)
@@ -136,11 +163,17 @@ class Logits(estimator.Estimator):
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = self.optimizer.minimize(self.loss, global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
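For context: `tf.contrib.framework.init_from_checkpoint` (TensorFlow 1.x) takes an assignment map from scopes in the checkpoint to scopes in the current graph, replacing the initializers of the matched variables so that training starts from the other model's values. An identity mapping restores variables under the same names, e.g. (placeholder path and scope):

tf.contrib.framework.init_from_checkpoint("/path/to/other_model",
                                          {"my_scope/": "my_scope/"})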
# Validation
if self.embedding_validation:
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}
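When validating with embeddings there are no logits to take an argmax over, so `predict_using_tensors` presumably labels each embedding with the label of its nearest neighbour inside the batch; a rough sketch under that assumption, not the actual bob.learn.tensorflow.utils implementation:

import tensorflow as tf

def sketch_predict_using_tensors(embeddings, labels, num):
    # Pairwise squared Euclidean distances between all embeddings in the batch
    sq = tf.reduce_sum(tf.square(embeddings), axis=1, keep_dims=True)
    dists = sq - 2.0 * tf.matmul(embeddings, embeddings, transpose_b=True) + tf.transpose(sq)
    # Mask the diagonal so a sample cannot match itself
    dists += tf.diag(tf.fill([num], float("inf")))
    # Each sample predicts the label of its closest neighbour
    return tf.gather(labels, tf.argmin(dists, axis=1))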
@@ -156,6 +189,7 @@ class Logits(estimator.Estimator):
super(Logits, self).__init__(model_fn=_model_fn,
model_dir=model_dir,
params=params,
config=config)
@@ -200,6 +234,10 @@ class LogitsCenterLoss(estimator.Estimator):
validation_batch_size:
Size of the batch for validation. This value is used when the
validation with embeddings is used. This is a hack.
params:
Extra params for the model function (please see https://www.tensorflow.org/extend/estimators for more info)
"""
def __init__(self,
@@ -212,6 +250,8 @@ class LogitsCenterLoss(estimator.Estimator):
alpha=0.9,
factor=0.01,
validation_batch_size=None,
params=None,
extra_checkpoint=None,
):
self.architecture = architecture
@@ -221,6 +261,7 @@ class LogitsCenterLoss(estimator.Estimator):
self.factor = factor
self.loss = None
self.embedding_validation = embedding_validation
self.extra_checkpoint = extra_checkpoint
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture !!")
@@ -237,17 +278,25 @@ class LogitsCenterLoss(estimator.Estimator):
data = features['data']
key = features['key']
# Building one graph, by default everything is trainable
if self.extra_checkpoint is None:
is_trainable = True
else:
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
prelogits = self.architecture(data, is_trainable=is_trainable)[0]
logits = append_logits(prelogits, n_classes)
if self.embedding_validation:
# Compute the embeddings
embeddings = tf.nn.l2_normalize(prelogits, 1)
predictions = {
"embeddings": embeddings
}
else:
predictions = {
# Generate predictions (for PREDICT and EVAL mode)
@@ -255,20 +304,22 @@ class LogitsCenterLoss(estimator.Estimator):
# Add `softmax_tensor` to the graph. It is used for PREDICT and by the
# `logging_hook`.
"probabilities": tf.nn.softmax(logits, name="softmax_tensor")
}
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
# Compute Loss (for both TRAIN and EVAL modes)
loss_dict = mean_cross_entropy_center_loss(logits, prelogits, labels, self.n_classes,
alpha=self.alpha, factor=self.factor)
self.loss = loss_dict['loss']
centers = loss_dict['centers']
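For reference, `mean_cross_entropy_center_loss` presumably implements the center loss of Wen et al. (ECCV 2016) on top of the usual softmax cross-entropy; a minimal sketch of that combination, not the actual bob.learn.tensorflow.loss code:

import tensorflow as tf

def sketch_center_loss(logits, prelogits, labels, n_classes, alpha=0.9, factor=0.01):
    # Standard softmax cross-entropy term
    xent = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
    # One (non-trainable) center per class in the prelogits space
    centers = tf.get_variable("centers", [n_classes, int(prelogits.get_shape()[1])],
                              initializer=tf.zeros_initializer(), trainable=False)
    centers_batch = tf.gather(centers, labels)
    # Penalize the distance of each sample to its class center
    center_loss = tf.reduce_mean(tf.reduce_sum(tf.square(prelogits - centers_batch), 1))
    # Update op pulling the centers towards the batch; grouped into train_op below
    update_centers = tf.scatter_sub(centers, labels, alpha * (centers_batch - prelogits))
    return {"loss": xent + factor * center_loss, "centers": update_centers}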
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
# Loading variables from some model just in case
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
global_step = tf.contrib.framework.get_or_create_global_step()
# backprop and updating the centers
train_op = tf.group(self.optimizer.minimize(self.loss, global_step=global_step), centers)
@@ -277,6 +328,7 @@ class LogitsCenterLoss(estimator.Estimator):
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
if self.embedding_validation:
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}