Skip to content
Snippets Groups Projects

Estimator that loads variables from another model

Merged Tiago de Freitas Pereira requested to merge issue-44 into master
All threads resolved!
7 files
+ 330
35
Compare changes
  • Side-by-side
  • Inline
Files
7
@@ -16,8 +16,8 @@ from bob.learn.tensorflow.network.utils import append_logits
@@ -16,8 +16,8 @@ from bob.learn.tensorflow.network.utils import append_logits
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import estimator
from bob.learn.tensorflow.utils import predict_using_tensors
from bob.learn.tensorflow.utils import predict_using_tensors
from bob.learn.tensorflow.loss import mean_cross_entropy_center_loss
from bob.learn.tensorflow.loss import mean_cross_entropy_center_loss
from . import check_features
 
from . import check_features, is_trainable_checkpoint
import logging
import logging
logger = logging.getLogger("bob.learn")
logger = logging.getLogger("bob.learn")
@@ -43,6 +43,15 @@ class Logits(estimator.Estimator):
@@ -43,6 +43,15 @@ class Logits(estimator.Estimator):
return loss_set_of_ops(logits, labels)
return loss_set_of_ops(logits, labels)
 
Variables, scopes... from other models can be loaded by the model_fn.
 
For that, please, wrap the the path of the OTHER checkpoint and the list
 
of variables in a dictionary with the key "load_variable_from_checkpoint" an provide them to the keyword `params`:
 
 
{"load_variable_from_checkpoint": {"checkpoint_path":"mypath",
 
"scopes":{"my_scope/": my_scope/}}}
 
 
 
**Parameters**
**Parameters**
architecture:
architecture:
Pointer to a function that builds the graph.
Pointer to a function that builds the graph.
@@ -70,6 +79,17 @@ class Logits(estimator.Estimator):
@@ -70,6 +79,17 @@ class Logits(estimator.Estimator):
validation_batch_size:
validation_batch_size:
Size of the batch for validation. This value is used when the
Size of the batch for validation. This value is used when the
validation with embeddings is used. This is a hack.
validation with embeddings is used. This is a hack.
 
 
params:
 
Extra params for the model function (please see https://www.tensorflow.org/extend/estimators for more info)
 
 
extra_checkpoint: dict()
 
In case you want to use other model to initialize some variables.
 
This argument should be in the following format
 
extra_checkpoint = {"checkpoint_path": <YOUR_CHECKPOINT>,
 
"scopes": dict({"<SOURCE_SCOPE>/": "<TARGET_SCOPE>/"}),
 
"is_trainable": <IF_THOSE_LOADED_VARIABLES_ARE_TRAINABLE>
 
+3
}
"""
"""
def __init__(self,
def __init__(self,
@@ -81,6 +101,8 @@ class Logits(estimator.Estimator):
@@ -81,6 +101,8 @@ class Logits(estimator.Estimator):
embedding_validation=False,
embedding_validation=False,
model_dir="",
model_dir="",
validation_batch_size=None,
validation_batch_size=None,
 
params=None,
 
extra_checkpoint=None
):
):
self.architecture = architecture
self.architecture = architecture
@@ -89,6 +111,7 @@ class Logits(estimator.Estimator):
@@ -89,6 +111,7 @@ class Logits(estimator.Estimator):
self.loss_op=loss_op
self.loss_op=loss_op
self.loss = None
self.loss = None
self.embedding_validation = embedding_validation
self.embedding_validation = embedding_validation
 
self.extra_checkpoint = extra_checkpoint
if self.architecture is None:
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture !!")
raise ValueError("Please specify a function to build the architecture !!")
@@ -107,9 +130,14 @@ class Logits(estimator.Estimator):
@@ -107,9 +130,14 @@ class Logits(estimator.Estimator):
check_features(features)
check_features(features)
data = features['data']
data = features['data']
key = features['key']
key = features['key']
 
 
# Building one graph, by default everything is trainable
 
if self.extra_checkpoint is None:
 
is_trainable = True
 
else:
 
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
# Building one graph
prelogits = self.architecture(data, is_trainable=is_trainable)[0]
prelogits = self.architecture(data)[0]
logits = append_logits(prelogits, n_classes)
logits = append_logits(prelogits, n_classes)
if self.embedding_validation:
if self.embedding_validation:
@@ -118,7 +146,6 @@ class Logits(estimator.Estimator):
@@ -118,7 +146,6 @@ class Logits(estimator.Estimator):
predictions = {
predictions = {
"embeddings": embeddings
"embeddings": embeddings
}
}
else:
else:
predictions = {
predictions = {
# Generate predictions (for PREDICT and EVAL mode)
# Generate predictions (for PREDICT and EVAL mode)
@@ -136,11 +163,17 @@ class Logits(estimator.Estimator):
@@ -136,11 +163,17 @@ class Logits(estimator.Estimator):
# Configure the Training Op (for TRAIN mode)
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
if mode == tf.estimator.ModeKeys.TRAIN:
 
if self.extra_checkpoint is not None:
 
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
 
self.extra_checkpoint["scopes"])
 
global_step = tf.contrib.framework.get_or_create_global_step()
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = self.optimizer.minimize(self.loss, global_step=global_step)
train_op = self.optimizer.minimize(self.loss, global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
train_op=train_op)
 
 
# Validation
if self.embedding_validation:
if self.embedding_validation:
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}
eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}
@@ -156,6 +189,7 @@ class Logits(estimator.Estimator):
@@ -156,6 +189,7 @@ class Logits(estimator.Estimator):
super(Logits, self).__init__(model_fn=_model_fn,
super(Logits, self).__init__(model_fn=_model_fn,
model_dir=model_dir,
model_dir=model_dir,
 
params=params,
config=config)
config=config)
@@ -200,6 +234,10 @@ class LogitsCenterLoss(estimator.Estimator):
@@ -200,6 +234,10 @@ class LogitsCenterLoss(estimator.Estimator):
validation_batch_size:
validation_batch_size:
Size of the batch for validation. This value is used when the
Size of the batch for validation. This value is used when the
validation with embeddings is used. This is a hack.
validation with embeddings is used. This is a hack.
 
 
params:
 
Extra params for the model function (please see https://www.tensorflow.org/extend/estimators for more info)
 
"""
"""
def __init__(self,
def __init__(self,
@@ -212,6 +250,8 @@ class LogitsCenterLoss(estimator.Estimator):
@@ -212,6 +250,8 @@ class LogitsCenterLoss(estimator.Estimator):
alpha=0.9,
alpha=0.9,
factor=0.01,
factor=0.01,
validation_batch_size=None,
validation_batch_size=None,
 
params=None,
 
extra_checkpoint=None,
):
):
self.architecture = architecture
self.architecture = architecture
@@ -221,6 +261,7 @@ class LogitsCenterLoss(estimator.Estimator):
@@ -221,6 +261,7 @@ class LogitsCenterLoss(estimator.Estimator):
self.factor = factor
self.factor = factor
self.loss = None
self.loss = None
self.embedding_validation = embedding_validation
self.embedding_validation = embedding_validation
 
self.extra_checkpoint = extra_checkpoint
if self.architecture is None:
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture !!")
raise ValueError("Please specify a function to build the architecture !!")
@@ -237,17 +278,25 @@ class LogitsCenterLoss(estimator.Estimator):
@@ -237,17 +278,25 @@ class LogitsCenterLoss(estimator.Estimator):
data = features['data']
data = features['data']
key = features['key']
key = features['key']
# Building one graph
# Building one graph, by default everything is trainable
 
if self.extra_checkpoint is None:
 
is_trainable = True
 
else:
 
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
 
prelogits = self.architecture(data)[0]
prelogits = self.architecture(data)[0]
logits = append_logits(prelogits, n_classes)
logits = append_logits(prelogits, n_classes)
 
# Compute Loss (for both TRAIN and EVAL modes)
 
loss_dict = mean_cross_entropy_center_loss(logits, prelogits, labels, self.n_classes,
 
alpha=self.alpha, factor=self.factor)
 
if self.embedding_validation:
if self.embedding_validation:
# Compute the embeddings
# Compute the embeddings
embeddings = tf.nn.l2_normalize(prelogits, 1)
embeddings = tf.nn.l2_normalize(prelogits, 1)
predictions = {
predictions = {
"embeddings": embeddings
"embeddings": embeddings
}
}
else:
else:
predictions = {
predictions = {
# Generate predictions (for PREDICT and EVAL mode)
# Generate predictions (for PREDICT and EVAL mode)
@@ -255,20 +304,22 @@ class LogitsCenterLoss(estimator.Estimator):
@@ -255,20 +304,22 @@ class LogitsCenterLoss(estimator.Estimator):
# Add `softmax_tensor` to the graph. It is used for PREDICT and by the
# Add `softmax_tensor` to the graph. It is used for PREDICT and by the
# `logging_hook`.
# `logging_hook`.
"probabilities": tf.nn.softmax(logits, name="softmax_tensor")
"probabilities": tf.nn.softmax(logits, name="softmax_tensor")
}
}
if mode == tf.estimator.ModeKeys.PREDICT:
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
# Compute Loss (for both TRAIN and EVAL modes)
loss_dict = mean_cross_entropy_center_loss(logits, prelogits, labels, self.n_classes,
alpha=self.alpha, factor=self.factor)
self.loss = loss_dict['loss']
self.loss = loss_dict['loss']
centers = loss_dict['centers']
centers = loss_dict['centers']
# Configure the Training Op (for TRAIN mode)
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
if mode == tf.estimator.ModeKeys.TRAIN:
 
# Loading variables from some model just in case
 
if self.extra_checkpoint is not None:
 
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
 
self.extra_checkpoint["scopes"])
 
global_step = tf.contrib.framework.get_or_create_global_step()
global_step = tf.contrib.framework.get_or_create_global_step()
# backprop and updating the centers
# backprop and updating the centers
train_op = tf.group(self.optimizer.minimize(self.loss, global_step=global_step),
train_op = tf.group(self.optimizer.minimize(self.loss, global_step=global_step),
@@ -277,6 +328,7 @@ class LogitsCenterLoss(estimator.Estimator):
@@ -277,6 +328,7 @@ class LogitsCenterLoss(estimator.Estimator):
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
train_op=train_op)
 
if self.embedding_validation:
if self.embedding_validation:
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}
eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}
Loading