Commit fdac5f7c authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Merge branch 'issue-44' into 'master'

Estimator that loads variables from another model

See merge request !27
parents 5b109f10 fe91ec3f
Pipeline #13554 failed with stages
in 2 minutes and 48 seconds
......@@ -16,8 +16,8 @@ from bob.learn.tensorflow.network.utils import append_logits
from tensorflow.python.estimator import estimator
from bob.learn.tensorflow.utils import predict_using_tensors
from bob.learn.tensorflow.loss import mean_cross_entropy_center_loss
from . import check_features
from . import check_features, is_trainable_checkpoint
import logging
logger = logging.getLogger("bob.learn")
......@@ -43,6 +43,15 @@ class Logits(estimator.Estimator):
return loss_set_of_ops(logits, labels)
Variables, scopes... from other models can be loaded by the model_fn.
For that, please, wrap the path of the OTHER checkpoint and the list
of variables in a dictionary with the key "load_variable_from_checkpoint" and provide them to the keyword `params`:
{"load_variable_from_checkpoint": {"checkpoint_path":"mypath",
"scopes":{"my_scope/": "my_scope/"}}}
**Parameters**
architecture:
Pointer to a function that builds the graph.
......@@ -70,6 +79,17 @@ class Logits(estimator.Estimator):
validation_batch_size:
Size of the batch for validation. This value is used when the
validation with embeddings is used. This is a hack.
params:
Extra params for the model function (please see https://www.tensorflow.org/extend/estimators for more info)
extra_checkpoint: dict()
In case you want to use other model to initialize some variables.
This argument should be in the following format
extra_checkpoint = {"checkpoint_path": <YOUR_CHECKPOINT>,
"scopes": dict({"<SOURCE_SCOPE>/": "<TARGET_SCOPE>/"}),
"is_trainable": <IF_THOSE_LOADED_VARIABLES_ARE_TRAINABLE>
}
"""
def __init__(self,
......@@ -81,6 +101,8 @@ class Logits(estimator.Estimator):
embedding_validation=False,
model_dir="",
validation_batch_size=None,
params=None,
extra_checkpoint=None
):
self.architecture = architecture
......@@ -89,6 +111,7 @@ class Logits(estimator.Estimator):
self.loss_op=loss_op
self.loss = None
self.embedding_validation = embedding_validation
self.extra_checkpoint = extra_checkpoint
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture !!")
......@@ -107,9 +130,14 @@ class Logits(estimator.Estimator):
check_features(features)
data = features['data']
key = features['key']
# Building one graph, by default everything is trainable
if self.extra_checkpoint is None:
is_trainable = True
else:
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
# Building one graph
prelogits = self.architecture(data)[0]
prelogits = self.architecture(data, is_trainable=is_trainable)[0]
logits = append_logits(prelogits, n_classes)
if self.embedding_validation:
......@@ -118,7 +146,6 @@ class Logits(estimator.Estimator):
predictions = {
"embeddings": embeddings
}
else:
predictions = {
# Generate predictions (for PREDICT and EVAL mode)
......@@ -136,11 +163,17 @@ class Logits(estimator.Estimator):
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = self.optimizer.minimize(self.loss, global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
# Validation
if self.embedding_validation:
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}
......@@ -156,6 +189,7 @@ class Logits(estimator.Estimator):
super(Logits, self).__init__(model_fn=_model_fn,
model_dir=model_dir,
params=params,
config=config)
......@@ -200,6 +234,10 @@ class LogitsCenterLoss(estimator.Estimator):
validation_batch_size:
Size of the batch for validation. This value is used when the
validation with embeddings is used. This is a hack.
params:
Extra params for the model function (please see https://www.tensorflow.org/extend/estimators for more info)
"""
def __init__(self,
......@@ -212,6 +250,8 @@ class LogitsCenterLoss(estimator.Estimator):
alpha=0.9,
factor=0.01,
validation_batch_size=None,
params=None,
extra_checkpoint=None,
):
self.architecture = architecture
......@@ -221,6 +261,7 @@ class LogitsCenterLoss(estimator.Estimator):
self.factor = factor
self.loss = None
self.embedding_validation = embedding_validation
self.extra_checkpoint = extra_checkpoint
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture !!")
......@@ -237,17 +278,25 @@ class LogitsCenterLoss(estimator.Estimator):
data = features['data']
key = features['key']
# Building one graph
# Building one graph, by default everything is trainable
if self.extra_checkpoint is None:
is_trainable = True
else:
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
prelogits = self.architecture(data)[0]
logits = append_logits(prelogits, n_classes)
# Compute Loss (for both TRAIN and EVAL modes)
loss_dict = mean_cross_entropy_center_loss(logits, prelogits, labels, self.n_classes,
alpha=self.alpha, factor=self.factor)
if self.embedding_validation:
# Compute the embeddings
embeddings = tf.nn.l2_normalize(prelogits, 1)
predictions = {
"embeddings": embeddings
}
else:
predictions = {
# Generate predictions (for PREDICT and EVAL mode)
......@@ -255,20 +304,22 @@ class LogitsCenterLoss(estimator.Estimator):
# Add `softmax_tensor` to the graph. It is used for PREDICT and by the
# `logging_hook`.
"probabilities": tf.nn.softmax(logits, name="softmax_tensor")
}
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
# Compute Loss (for both TRAIN and EVAL modes)
loss_dict = mean_cross_entropy_center_loss(logits, prelogits, labels, self.n_classes,
alpha=self.alpha, factor=self.factor)
self.loss = loss_dict['loss']
centers = loss_dict['centers']
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
# Loading variables from some model just in case
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
global_step = tf.contrib.framework.get_or_create_global_step()
# backprop and updating the centers
train_op = tf.group(self.optimizer.minimize(self.loss, global_step=global_step),
......@@ -277,6 +328,7 @@ class LogitsCenterLoss(estimator.Estimator):
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
if self.embedding_validation:
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}
......
......@@ -13,12 +13,16 @@ import time
from tensorflow.python.estimator import estimator
from bob.learn.tensorflow.utils import predict_using_tensors
#from bob.learn.tensorflow.loss import mean_cross_entropy_center_loss
from . import check_features
from . import check_features, is_trainable_checkpoint
import logging
logger = logging.getLogger("bob.learn")
from bob.learn.tensorflow.network.utils import append_logits
from bob.learn.tensorflow.loss import mean_cross_entropy_loss
class Siamese(estimator.Estimator):
"""
NN estimator for Siamese networks
......@@ -39,6 +43,14 @@ class Siamese(estimator.Estimator):
return loss_set_of_ops(logits, labels)
extra_checkpoint = {"checkpoint_path":model_dir,
"scopes": dict({"Dummy/": "Dummy/"}),
"is_trainable": False
}
**Parameters**
architecture:
Pointer to a function that builds the graph.
......@@ -51,9 +63,6 @@ class Siamese(estimator.Estimator):
config:
n_classes:
Number of classes of your problem. The logits will be appended in this class
loss_op:
Pointer to a function that computes the loss.
......@@ -66,23 +75,38 @@ class Siamese(estimator.Estimator):
validation_batch_size:
Size of the batch for validation. This value is used when the
validation with embeddings is used. This is a hack.
params:
Extra params for the model function
(please see https://www.tensorflow.org/extend/estimators for more info)
extra_checkpoint: dict()
In case you want to use other model to initialize some variables.
This argument should be in the following format
extra_checkpoint = {"checkpoint_path": <YOUR_CHECKPOINT>,
"scopes": dict({"<SOURCE_SCOPE>/": "<TARGET_SCOPE>/"}),
"is_trainable": <IF_THOSE_LOADED_VARIABLES_ARE_TRAINABLE>
}
"""
def __init__(self,
architecture=None,
optimizer=None,
config=None,
n_classes=0,
loss_op=None,
model_dir="",
validation_batch_size=None,
params=None,
extra_checkpoint=None
):
self.architecture = architecture
self.optimizer=optimizer
self.n_classes=n_classes
self.loss_op=loss_op
self.loss = None
self.extra_checkpoint = extra_checkpoint
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture !!")
......@@ -93,26 +117,34 @@ class Siamese(estimator.Estimator):
if self.loss_op is None:
raise ValueError("Please specify a function to build the loss !!")
if self.n_classes <= 0:
raise ValueError("Number of classes must be greated than 0")
def _model_fn(features, labels, mode, params, config):
if mode == tf.estimator.ModeKeys.TRAIN:
# Building one graph, by default everything is trainable
if self.extra_checkpoint is None:
is_trainable = True
else:
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
# The input function needs to have dictionary pair with the `left` and `right` keys
if not 'left' in features.keys() or not 'right' in features.keys():
raise ValueError("The input function needs to contain a dictionary with the keys `left` and `right` ")
# Building one graph
prelogits_left = self.architecture(features['left'])[0]
prelogits_right = self.architecture(features['right'], reuse=True)[0]
prelogits_left = self.architecture(features['left'], is_trainable=is_trainable)[0]
prelogits_right = self.architecture(features['right'], reuse=True, is_trainable=is_trainable)[0]
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
# Compute Loss (for both TRAIN and EVAL modes)
self.loss = self.loss_op(prelogits_left, prelogits_right, labels)
self.loss = self.loss_op(prelogits_left, prelogits_left, labels)
# Configure the Training Op (for TRAIN mode)
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = self.optimizer.minimize(self.loss, global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
......@@ -135,5 +167,6 @@ class Siamese(estimator.Estimator):
super(Siamese, self).__init__(model_fn=_model_fn,
model_dir=model_dir,
params=params,
config=config)
......@@ -2,12 +2,22 @@
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
import tensorflow as tf
def check_features(features):
    """
    Validate the feature dictionary handed to a model function.

    **Parameters**
      features:
        Mapping produced by the input function. It must contain the
        keys `data` and `key`.

    **Returns**
      ``True`` when both keys are present.

    **Raises**
      ValueError: if `data` or `key` is missing.
    """
    if 'data' not in features or 'key' not in features:
        raise ValueError("The input function needs to contain a dictionary with the keys `data` and `key` ")
    return True
def is_trainable_checkpoint(params):
    """
    Read the `is_trainable` flag from an `extra_checkpoint` dictionary.

    **Parameters**
      params:
        The `extra_checkpoint` dictionary given to an estimator. It must
        contain the key `is_trainable`.

    **Returns**
      The value stored under `is_trainable`, returned unchanged.

    **Raises**
      ValueError: if the key `is_trainable` is absent.
    """
    if "is_trainable" not in params:
        # The estimators pass their `extra_checkpoint` argument here, so the
        # message names that dictionary.
        raise ValueError("Param `is_trainable` is missing in `extra_checkpoint` dictionary")
    return params["is_trainable"]
from .Logits import Logits, LogitsCenterLoss
from .Siamese import Siamese
from .Triplet import Triplet
......
......@@ -5,7 +5,7 @@
import tensorflow as tf
def dummy(inputs, reuse=False):
def dummy(inputs, reuse=False, is_trainable=True):
"""
Create all the necessary variables for this CNN
......@@ -19,10 +19,12 @@ def dummy(inputs, reuse=False):
end_points = dict()
with tf.variable_scope('Dummy', reuse=reuse):
initializer = tf.contrib.layers.xavier_initializer()
graph = slim.conv2d(inputs, 10, [3, 3], activation_fn=tf.nn.relu, stride=1, scope='conv1',
weights_initializer=initializer)
weights_initializer=initializer,
trainable=is_trainable)
end_points['conv1'] = graph
graph = slim.max_pool2d(graph, [4, 4], scope='pool1')
......@@ -34,8 +36,10 @@ def dummy(inputs, reuse=False):
graph = slim.fully_connected(graph, 50,
weights_initializer=initializer,
activation_fn=None,
scope='fc1')
scope='fc1',
trainable=is_trainable)
end_points['fc1'] = graph
return graph, end_points
......@@ -176,10 +176,10 @@ def run_logitstrainer_mnist(trainer, augmentation=False):
if not trainer.embedding_validation:
acc = trainer.evaluate(input_fn_validation)
assert acc['accuracy'] > 0.80
assert acc['accuracy'] > 0.40
else:
acc = trainer.evaluate(input_fn_validation)
assert acc['accuracy'] > 0.80
assert acc['accuracy'] > 0.40
# Cleaning up
tf.reset_default_graph()
......
......@@ -5,13 +5,16 @@
import tensorflow as tf
from bob.learn.tensorflow.network import dummy
from bob.learn.tensorflow.estimators import Siamese
from bob.learn.tensorflow.estimators import Siamese, Logits
from bob.learn.tensorflow.dataset.siamese_image import shuffle_data_and_labels_image_augmentation as siamese_batch
from bob.learn.tensorflow.dataset.image import shuffle_data_and_labels_image_augmentation as single_batch
from bob.learn.tensorflow.loss import contrastive_loss
from bob.learn.tensorflow.loss import contrastive_loss, mean_cross_entropy_loss
from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator
from bob.learn.tensorflow.utils import reproducible
from .test_estimator_transfer import dummy_adapted
import pkg_resources
import numpy
......@@ -22,8 +25,9 @@ import os
tfrecord_train = "./train_mnist.tfrecord"
tfrecord_validation = "./validation_mnist.tfrecord"
model_dir = "./temp"
model_dir_adapted = "./temp2"
learning_rate = 0.001
learning_rate = 0.0001
data_shape = (250, 250, 3) # size of atnt images
output_shape = (50, 50)
data_type = tf.float32
......@@ -59,13 +63,12 @@ labels = [0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1]
def test_logitstrainer():
def test_siamesetrainer():
# Trainer logits
try:
trainer = Siamese(model_dir=model_dir,
architecture=dummy,
optimizer=tf.train.GradientDescentOptimizer(learning_rate),
n_classes=10,
loss_op=contrastive_loss,
validation_batch_size=validation_batch_size)
run_siamesetrainer(trainer)
......@@ -77,6 +80,45 @@ def test_logitstrainer():
pass
def test_siamesetrainer_transfer():
    """
    Train a `Logits` estimator, then a `Siamese` estimator that loads the
    `Dummy/` scope from the first model's checkpoint directory.

    The variables loaded through `extra_checkpoint` are flagged as not
    trainable. Both model directories are removed afterwards, whether or
    not training succeeded.
    """

    def logits_input_fn():
        return single_batch(filenames, labels, data_shape, data_type, batch_size, epochs=epochs, output_shape=output_shape)

    # Train the logits model first, then the siamese model
    try:
        extra_checkpoint = {"checkpoint_path": model_dir,
                            "scopes": dict({"Dummy/": "Dummy/"}),
                            "is_trainable": False
                            }

        # Logits
        logits_trainer = Logits(model_dir=model_dir,
                                architecture=dummy,
                                optimizer=tf.train.GradientDescentOptimizer(learning_rate),
                                n_classes=2,
                                loss_op=mean_cross_entropy_loss,
                                embedding_validation=False,
                                validation_batch_size=validation_batch_size)
        logits_trainer.train(logits_input_fn, steps=steps)

        # Siamese, initialised from the logits checkpoint
        trainer = Siamese(model_dir=model_dir_adapted,
                          architecture=dummy_adapted,
                          optimizer=tf.train.GradientDescentOptimizer(learning_rate),
                          loss_op=contrastive_loss,
                          validation_batch_size=validation_batch_size,
                          extra_checkpoint=extra_checkpoint)
        run_siamesetrainer(trainer)
    finally:
        # Best-effort cleanup. Each directory is handled on its own so that a
        # failure while removing one does not leave the other behind.
        for directory in (model_dir, model_dir_adapted):
            try:
                shutil.rmtree(directory, ignore_errors=True)
            except OSError:
                pass
def run_siamesetrainer(trainer):
# Cleaning up
......@@ -97,7 +139,7 @@ def run_siamesetrainer(trainer):
scaffold=tf.train.Scaffold(),
summary_writer=tf.summary.FileWriter(model_dir) )]
trainer.train(input_fn, steps=steps, hooks=hooks)
trainer.train(input_fn, steps=1, hooks=hooks)
acc = trainer.evaluate(input_validation_fn)
assert acc['accuracy'] > 0.5
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
import tensorflow as tf
from bob.learn.tensorflow.network import dummy
from bob.learn.tensorflow.estimators import Logits, LogitsCenterLoss
from bob.learn.tensorflow.dataset.tfrecords import shuffle_data_and_labels, batch_data_and_labels, shuffle_data_and_labels_image_augmentation
from bob.learn.tensorflow.dataset import append_image_augmentation
from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord
from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator
from bob.learn.tensorflow.utils import reproducible
from bob.learn.tensorflow.loss import mean_cross_entropy_loss
from .test_estimator_onegraph import run_logitstrainer_mnist
import numpy
import shutil
import os
# Paths of the tfrecord files; the visible cleanup code unlinks them after the tests
tfrecord_train = "./train_mnist.tfrecord"
tfrecord_validation = "./validation_mnist.tfrecord"
# Checkpoint directory of the first (source) estimator
model_dir = "./temp"
# Checkpoint directory of the estimator that loads variables from `model_dir`
model_dir_adapted = "./temp2"
# Learning rate given to the gradient descent optimizers below
learning_rate = 0.1
data_shape = (28, 28, 1) # image shape; presumably MNIST (28x28, 1 channel) - TODO confirm
data_type = tf.float32
batch_size = 16
# Forwarded to the estimators as `validation_batch_size`
validation_batch_size = 250
epochs = 2
steps = 5000
def dummy_adapted(inputs, reuse=False, is_trainable=False):
    """
    Stack two extra fully connected layers on top of the `dummy` network.

    **Parameters**
      inputs:
        Input passed straight to `dummy`.

      reuse:
        Forwarded to `dummy` and to the `Adapted` variable scope.

      is_trainable:
        Forwarded to `dummy` only; the two layers added here do not
        receive it.

    **Returns**
      A tuple with the output of the last layer (`fc3`) and the
      `end_points` dictionary returned by `dummy`, extended with the
      entries `fc2` and `fc3`.
    """
    slim = tf.contrib.slim

    # Base network
    net, end_points = dummy(inputs, reuse=reuse, is_trainable=is_trainable)

    xavier = tf.contrib.layers.xavier_initializer()
    with tf.variable_scope('Adapted', reuse=reuse):
        # 50 units followed by a ReLU
        fc2 = slim.fully_connected(net, 50,
                                   weights_initializer=xavier,
                                   activation_fn=tf.nn.relu,
                                   scope='fc2')
        end_points['fc2'] = fc2

        # 25 units, no activation
        fc3 = slim.fully_connected(fc2, 25,
                                   weights_initializer=xavier,
                                   activation_fn=None,
                                   scope='fc3')
        end_points['fc3'] = fc3

    return fc3, end_points
def test_logitstrainer():
# Trainer logits
try:
embedding_validation = False
trainer = Logits(model_dir=model_dir,
architecture=dummy,
optimizer=tf.train.GradientDescentOptimizer(learning_rate),
n_classes=10,
loss_op=mean_cross_entropy_loss,
embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size)
run_logitstrainer_mnist(trainer, augmentation=True)
del trainer
## Again
extra_checkpoint = {"checkpoint_path":"./temp",
"scopes": dict({"Dummy/": "Dummy/"}),
"is_trainable": False
}
trainer = Logits(model_dir=model_dir_adapted,
architecture=dummy_adapted,
optimizer=tf.train.GradientDescentOptimizer(learning_rate),
n_classes=10,
loss_op=mean_cross_entropy_loss,
embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size,
extra_checkpoint=extra_checkpoint
)
run_logitstrainer_mnist(trainer, augmentation=True)
finally:
try:
os.unlink(tfrecord_train)
os.unlink(tfrecord_validation)
shutil.rmtree(model_dir,