Commit 7fb73d81 authored by Tiago de Freitas Pereira

Implemented transfer mechanism

parent aeaf3303
@@ -16,8 +16,8 @@ from bob.learn.tensorflow.network.utils import append_logits
from tensorflow.python.estimator import estimator
from bob.learn.tensorflow.utils import predict_using_tensors
from bob.learn.tensorflow.loss import mean_cross_entropy_center_loss
from . import check_features
from . import check_features, is_trainable_checkpoint
import logging
logger = logging.getLogger("bob.learn")
@@ -43,6 +43,15 @@ class Logits(estimator.Estimator):
return loss_set_of_ops(logits, labels)
Variables and scopes from other models can be loaded by the model_fn.
For that, wrap the path of the OTHER checkpoint, the mapping between its
scopes and the scopes of the current graph, and the trainability flag in a
dictionary and provide it via the keyword `extra_checkpoint`:
{"checkpoint_path": "mypath",
"scopes": {"my_scope/": "my_scope/"},
"is_trainable": False}
**Parameters**
architecture:
Pointer to a function that builds the graph.
@@ -70,6 +79,9 @@ class Logits(estimator.Estimator):
validation_batch_size:
Size of the batch for validation. This value is used when the
validation with embeddings is used. This is a hack.
params:
Extra params for the model function (please see https://www.tensorflow.org/extend/estimators for more info)
"""
def __init__(self,
@@ -81,6 +93,8 @@ class Logits(estimator.Estimator):
embedding_validation=False,
model_dir="",
validation_batch_size=None,
params=None,
extra_checkpoint=None
):
self.architecture = architecture
@@ -89,6 +103,7 @@ class Logits(estimator.Estimator):
self.loss_op = loss_op
self.loss = None
self.embedding_validation = embedding_validation
self.extra_checkpoint = extra_checkpoint
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture!")
@@ -107,18 +122,36 @@ class Logits(estimator.Estimator):
check_features(features)
data = features['data']
key = features['key']
# Building one graph, by default everything is trainable
if self.extra_checkpoint is None:
is_trainable = True
else:
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
# Building one graph
prelogits = self.architecture(data)[0]
prelogits = self.architecture(data, is_trainable=is_trainable)[0]
logits = append_logits(prelogits, n_classes)
# Compute Loss (for both TRAIN and EVAL modes)
self.loss = self.loss_op(logits, labels)
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
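# If a source checkpoint was given, initialize the mapped scopes from it before training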
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = self.optimizer.minimize(self.loss, global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
if self.embedding_validation:
# Compute the embeddings
embeddings = tf.nn.l2_normalize(prelogits, 1)
predictions = {
"embeddings": embeddings
}
else:
predictions = {
# Generate predictions (for PREDICT and EVAL mode)
@@ -131,16 +164,7 @@ class Logits(estimator.Estimator):
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
# Compute Loss (for both TRAIN and EVAL modes)
self.loss = self.loss_op(logits, labels)
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = self.optimizer.minimize(self.loss, global_step=global_step)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
# Validation
if self.embedding_validation:
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}
@@ -156,6 +180,7 @@ class Logits(estimator.Estimator):
super(Logits, self).__init__(model_fn=_model_fn,
model_dir=model_dir,
params=params,
config=config)
@@ -200,6 +225,10 @@ class LogitsCenterLoss(estimator.Estimator):
validation_batch_size:
Size of the batch for validation. This value is used when the
validation with embeddings is used. This is a hack.
params:
Extra params for the model function (please see https://www.tensorflow.org/extend/estimators for more info)
"""
def __init__(self,
@@ -212,6 +241,8 @@ class LogitsCenterLoss(estimator.Estimator):
alpha=0.9,
factor=0.01,
validation_batch_size=None,
params=None,
extra_checkpoint=None,
):
self.architecture = architecture
@@ -221,6 +252,7 @@ class LogitsCenterLoss(estimator.Estimator):
self.factor = factor
self.loss = None
self.embedding_validation = embedding_validation
self.extra_checkpoint = extra_checkpoint
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture!")
@@ -237,17 +269,42 @@ class LogitsCenterLoss(estimator.Estimator):
data = features['data']
key = features['key']
# Building one graph
# Building one graph, by default everything is trainable
if self.extra_checkpoint is None:
is_trainable = True
else:
is_trainable = is_trainable_checkpoint(self.extra_checkpoint)
prelogits = self.architecture(data, is_trainable=is_trainable)[0]
logits = append_logits(prelogits, n_classes)
# Compute Loss (for both TRAIN and EVAL modes)
loss_dict = mean_cross_entropy_center_loss(logits, prelogits, labels, self.n_classes,
alpha=self.alpha, factor=self.factor)
self.loss = loss_dict['loss']
centers = loss_dict['centers']
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
# Load variables from another model's checkpoint, if one was given
if self.extra_checkpoint is not None:
tf.contrib.framework.init_from_checkpoint(self.extra_checkpoint["checkpoint_path"],
self.extra_checkpoint["scopes"])
global_step = tf.contrib.framework.get_or_create_global_step()
# Backpropagate and update the centers
train_op = tf.group(self.optimizer.minimize(self.loss, global_step=global_step),
centers)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
if self.embedding_validation:
# Compute the embeddings
embeddings = tf.nn.l2_normalize(prelogits, 1)
predictions = {
"embeddings": embeddings
}
else:
predictions = {
# Generate predictions (for PREDICT and EVAL mode)
@@ -255,27 +312,11 @@ class LogitsCenterLoss(estimator.Estimator):
# Add `softmax_tensor` to the graph. It is used for PREDICT and by the
# `logging_hook`.
"probabilities": tf.nn.softmax(logits, name="softmax_tensor")
}
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
# Compute Loss (for both TRAIN and EVAL modes)
loss_dict = mean_cross_entropy_center_loss(logits, prelogits, labels, self.n_classes,
alpha=self.alpha, factor=self.factor)
self.loss = loss_dict['loss']
centers = loss_dict['centers']
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
global_step = tf.contrib.framework.get_or_create_global_step()
# backprop and updating the centers
train_op = tf.group(self.optimizer.minimize(self.loss, global_step=global_step),
centers)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
if self.embedding_validation:
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
@@ -5,7 +5,7 @@
import tensorflow as tf
def dummy(inputs, reuse=False):
def dummy(inputs, reuse=False, is_trainable=True):
"""
Create all the necessary variables for this CNN
@@ -22,7 +22,8 @@ def dummy(inputs, reuse=False):
initializer = tf.contrib.layers.xavier_initializer()
graph = slim.conv2d(inputs, 10, [3, 3], activation_fn=tf.nn.relu, stride=1, scope='conv1',
weights_initializer=initializer)
weights_initializer=initializer,
trainable=is_trainable)
end_points['conv1'] = graph
graph = slim.max_pool2d(graph, [4, 4], scope='pool1')
@@ -34,8 +35,10 @@ def dummy(inputs, reuse=False):
graph = slim.fully_connected(graph, 50,
weights_initializer=initializer,
activation_fn=None,
scope='fc1')
scope='fc1',
trainable=is_trainable)
end_points['fc1'] = graph
return graph, end_points
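For context, a minimal sketch (TF1/slim; tensor and scope names are hypothetical)
of why the `trainable` flag freezes a layer: variables created with trainable=False
stay out of the TRAINABLE_VARIABLES collection, which is what minimize() iterates over:

    import tensorflow as tf
    slim = tf.contrib.slim

    inputs = tf.placeholder(tf.float32, [None, 28, 28, 1])
    # Frozen layer: its kernel and bias are not added to tf.trainable_variables()
    frozen = slim.conv2d(inputs, 10, [3, 3], scope='conv1', trainable=False)
    head = slim.fully_connected(slim.flatten(frozen), 10, scope='fc_new')

    # Only the 'fc_new' variables show up here, so only they receive gradient updates
    print([v.name for v in tf.trainable_variables()])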
@@ -176,10 +176,10 @@ def run_logitstrainer_mnist(trainer, augmentation=False):
if not trainer.embedding_validation:
acc = trainer.evaluate(input_fn_validation)
assert acc['accuracy'] > 0.80
assert acc['accuracy'] > 0.40
else:
acc = trainer.evaluate(input_fn_validation)
assert acc['accuracy'] > 0.80
assert acc['accuracy'] > 0.40
# Cleaning up
tf.reset_default_graph()
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
import tensorflow as tf
from bob.learn.tensorflow.network import dummy
from bob.learn.tensorflow.estimators import Logits, LogitsCenterLoss
from bob.learn.tensorflow.dataset.tfrecords import shuffle_data_and_labels, batch_data_and_labels, shuffle_data_and_labels_image_augmentation
from bob.learn.tensorflow.dataset import append_image_augmentation
from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord
from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator
from bob.learn.tensorflow.utils import reproducible
from bob.learn.tensorflow.loss import mean_cross_entropy_loss
from .test_estimator_onegraph import run_logitstrainer_mnist
import numpy
import shutil
import os
tfrecord_train = "./train_mnist.tfrecord"
tfrecord_validation = "./validation_mnist.tfrecord"
model_dir = "./temp"
model_dir_adapted = "./temp2"
learning_rate = 0.1
data_shape = (28, 28, 1)  # size of MNIST images
data_type = tf.float32
batch_size = 16
validation_batch_size = 250
epochs = 2
steps = 5000
def dummy_adapted(inputs, reuse=False, is_trainable=False):
"""
Create all the necessary variables for this CNN
**Parameters**
inputs: input tensor
reuse: if True, reuse the existing variable scopes
is_trainable: if False, the transferred `dummy` layers are frozen
"""
slim = tf.contrib.slim
graph, end_points = dummy(inputs, reuse=reuse, is_trainable=is_trainable)
initializer = tf.contrib.layers.xavier_initializer()
with tf.variable_scope('Adapted', reuse=reuse):
graph = slim.fully_connected(graph, 50,
weights_initializer=initializer,
activation_fn=tf.nn.relu,
scope='fc2')
end_points['fc2'] = graph
graph = slim.fully_connected(graph, 25,
weights_initializer=initializer,
activation_fn=None,
scope='fc3')
end_points['fc3'] = graph
return graph, end_points
def test_logitstrainer():
# Trainer logits
try:
embedding_validation = False
trainer = Logits(model_dir=model_dir,
architecture=dummy,
optimizer=tf.train.GradientDescentOptimizer(learning_rate),
n_classes=10,
loss_op=mean_cross_entropy_loss,
embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size)
run_logitstrainer_mnist(trainer, augmentation=True)
del trainer
# Train again, transferring the 'Dummy/' scope from the first run
extra_checkpoint = {"checkpoint_path": "./temp",
"scopes": {"Dummy/": "Dummy/"},
"is_trainable": False}
trainer = Logits(model_dir=model_dir_adapted,
architecture=dummy_adapted,
optimizer=tf.train.GradientDescentOptimizer(learning_rate),
n_classes=10,
loss_op=mean_cross_entropy_loss,
embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size,
extra_checkpoint=extra_checkpoint
)
run_logitstrainer_mnist(trainer, augmentation=True)
finally:
try:
os.unlink(tfrecord_train)
os.unlink(tfrecord_validation)
shutil.rmtree(model_dir, ignore_errors=True)
shutil.rmtree(model_dir_adapted, ignore_errors=True)
except Exception:
pass
def test_logitstrainer_center_loss():
# Trainer logits
try:
embedding_validation = False
trainer = LogitsCenterLoss(model_dir=model_dir,
architecture=dummy,
optimizer=tf.train.GradientDescentOptimizer(learning_rate),
n_classes=10,
embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size)
run_logitstrainer_mnist(trainer, augmentation=True)
del trainer
# Train again, transferring the 'Dummy/' scope from the first run
extra_checkpoint = {"checkpoint_path": "./temp",
"scopes": {"Dummy/": "Dummy/"},
"is_trainable": False}
trainer = LogitsCenterLoss(model_dir=model_dir_adapted,
architecture=dummy_adapted,
optimizer=tf.train.GradientDescentOptimizer(learning_rate),
n_classes=10,
embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size,
extra_checkpoint=extra_checkpoint
)
run_logitstrainer_mnist(trainer, augmentation=True)
finally:
try:
os.unlink(tfrecord_train)
os.unlink(tfrecord_validation)
shutil.rmtree(model_dir, ignore_errors=True)
shutil.rmtree(model_dir_adapted, ignore_errors=True)
except Exception:
pass