Commit ffb39107 authored by Tiago de Freitas Pereira

Implemented center loss

parent 41f0f826
1 merge request: !21 Resolve "Adopt to the Estimators API"
@@ -32,6 +32,7 @@ def mean_cross_entropy_loss(logits, labels, add_regularization_losses=True):
else:
return loss
def mean_cross_entropy_center_loss(logits, prelogits, labels, n_classes, alpha=0.9, factor=0.01):
"""
Implementation of the CrossEntropy + Center Loss from the paper
@@ -58,7 +59,7 @@ def mean_cross_entropy_center_loss(logits, prelogits, labels, n_classes, alpha=0.9, factor=0.01):
centers = tf.get_variable('centers', [n_classes, n_features], dtype=tf.float32,
initializer=tf.constant_initializer(0), trainable=False)
-    label = tf.reshape(labels, [-1])
+    #label = tf.reshape(labels, [-1])
centers_batch = tf.gather(centers, labels)
diff = (1 - alpha) * (centers_batch - prelogits)
centers = tf.scatter_sub(centers, labels, diff)
@@ -68,7 +69,7 @@ def mean_cross_entropy_center_loss(logits, prelogits, labels, n_classes, alpha=0.9, factor=0.01):
# Adding the regularizers in the loss
with tf.variable_scope('total_loss'):
regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
total_loss = tf.add_n([loss] + regularization_losses, name=tf.GraphKeys.LOSSES)
loss = dict()
loss['loss'] = total_loss
......
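For reference, the update above is the center loss of Wen et al. ("A Discriminative Feature Learning Approach for Deep Face Recognition", ECCV 2016), which the docstring appears to cite: each class center c_y is pulled a fraction (1 - alpha) toward the prelogits of its samples, c_y <- c_y - (1 - alpha) * (c_y - x), and the total objective is the cross-entropy plus factor times the center loss. A minimal NumPy sketch of the same update (illustration only; the function name and shapes are hypothetical):

    import numpy as np

    def update_centers(centers, prelogits, labels, alpha=0.9):
        # Gather the old centers first, mirroring tf.gather above.
        diff = (1 - alpha) * (centers[labels] - prelogits)
        # np.subtract.at accumulates over repeated labels,
        # matching the semantics of tf.scatter_sub.
        np.subtract.at(centers, labels, diff)
        return centers

    centers = np.zeros((10, 2))                      # n_classes=10, n_features=2
    prelogits = np.array([[1.0, 1.0], [3.0, -1.0]])  # two class-0 embeddings
    labels = np.array([0, 0])
    update_centers(centers, prelogits, labels)       # center 0 drifts toward the samples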
@@ -2,23 +2,15 @@
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
import numpy
from bob.learn.tensorflow.datashuffler import Memory, SiameseMemory, TripletMemory, scale_factor
from bob.learn.tensorflow.network import chopra
from bob.learn.tensorflow.loss import mean_cross_entropy_loss, contrastive_loss, triplet_loss
from bob.learn.tensorflow.test.test_cnn_scratch import validate_network
from bob.learn.tensorflow.network.utils import append_logits
import tensorflow as tf
-from bob.learn.tensorflow.trainers import LogitsTrainer
from bob.learn.tensorflow.network import dummy
+from bob.learn.tensorflow.trainers import LogitsTrainer, LogitsCenterLossTrainer
from bob.learn.tensorflow.utils.tfrecords import shuffle_data_and_labels, batch_data_and_labels
from bob.learn.tensorflow.utils import load_mnist, create_mnist_tfrecord
from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator
from bob.learn.tensorflow.loss import mean_cross_entropy_loss
import shutil
import os
@@ -37,16 +29,25 @@ epochs = 1
steps = 2000
-def test_cnn_trainer():
-    run_cnn(False)

+def test_logitstrainer():
+    run_logitstrainer(False)

+def test_logitstrainer_embedding():
+    run_logitstrainer(True)

+def test_logitstrainer_centerloss():
+    run_logitstrainer_centerloss(False)

-def test_cnn_trainer_embedding():
-    run_cnn(True)

+def test_logitstrainer_centerloss_embedding():
+    run_logitstrainer_centerloss(True)

-def run_cnn(embedding_validation):
+def run_logitstrainer(embedding_validation):
# Cleaning up
tf.reset_default_graph()
assert len(tf.global_variables()) == 0
@@ -79,6 +80,7 @@ def run_cnn(embedding_validation):
trainer.train(input_fn, steps=steps, hooks=hooks)
if not embedding_validation:
acc = trainer.evaluate(input_fn_validation)
assert acc['accuracy'] > 0.80
else:
@@ -91,10 +93,70 @@ def run_cnn(embedding_validation):
os.unlink(tfrecord_validation)
shutil.rmtree(model_dir)
except Exception:
pass
# Cleaning up
tf.reset_default_graph()
assert len(tf.global_variables()) == 0
def run_logitstrainer_centerloss(embedding_validation):
# Cleaning up
tf.reset_default_graph()
assert len(tf.global_variables()) == 0
# Creating tf records for mnist
train_data, train_labels, validation_data, validation_labels = load_mnist()
create_mnist_tfrecord(tfrecord_train, train_data, train_labels, n_samples=6000)
create_mnist_tfrecord(tfrecord_validation, validation_data, validation_labels, n_samples=1000)
try:
# Trainer logits
trainer = LogitsCenterLossTrainer(
model_dir=model_dir,
architecture=dummy,
optimizer=tf.train.GradientDescentOptimizer(learning_rate),
n_classes=10,
embedding_validation=embedding_validation,
validation_batch_size=validation_batch_size,
factor=0.01
)
def input_fn():
return shuffle_data_and_labels(tfrecord_train, data_shape, data_type,
batch_size, epochs=epochs)
def input_fn_validation():
return batch_data_and_labels(tfrecord_validation, data_shape, data_type,
validation_batch_size, epochs=epochs)
hooks = [LoggerHookEstimator(trainer, 16, 100)]
trainer.train(input_fn, steps=steps, hooks=hooks)
if not embedding_validation:
acc = trainer.evaluate(input_fn_validation)
assert acc['accuracy'] > 0.80
else:
acc = trainer.evaluate(input_fn_validation)
assert acc['accuracy'] > 0.80
sess = tf.Session()
checkpoint_path = tf.train.get_checkpoint_state(model_dir).model_checkpoint_path
saver = tf.train.import_meta_graph(checkpoint_path + ".meta", clear_devices=True)
saver.restore(sess, tf.train.latest_checkpoint(model_dir))
centers = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="center_loss/centers:0")[0]
assert numpy.sum(numpy.abs(centers.eval(sess))) > 0.0
finally:
try:
os.unlink(tfrecord_train)
os.unlink(tfrecord_validation)
shutil.rmtree(model_dir)
except Exception:
pass
# Cleaning up
tf.reset_default_graph()
assert len(tf.global_variables()) == 0
@@ -14,8 +14,8 @@ import time
#logger = bob.core.log.setup("bob.learn.tensorflow")
from bob.learn.tensorflow.network.utils import append_logits
from tensorflow.python.estimator import estimator
from bob.learn.tensorflow.utils import reproducible
from bob.learn.tensorflow.utils import predict_using_tensors
from bob.learn.tensorflow.loss import mean_cross_entropy_center_loss
import logging
@@ -24,24 +24,41 @@ logger = logging.getLogger("bob.learn")
class LogitsTrainer(estimator.Estimator):
"""
-    Logits .
+    NN trainer with logits as the last layer.
The **architecture** function should follow this pattern:
def my_beautiful_function(placeholder):
end_points = dict()
graph = convXX(placeholder)
end_points['conv'] = graph
....
return graph, end_points
The **loss** function should follow this pattern:
def my_beautiful_loss(logits, labels):
return loss_set_of_ops(logits, labels)
**Parameters**
architecture:
Pointer to a function that builds the graph.
The signature should be something like `my_beautiful_function(input)`
optimizer:
One of the tensorflow solvers (https://www.tensorflow.org/api_guides/python/train)
- tf.train.GradientDescentOptimizer
- tf.train.AdagradOptimizer
- ....
config:
n_classes:
-      Number of classes of your problem
+      Number of classes of your problem. The logits will be appended by this class.
loss_op:
Pointer to a function that computes the loss.
The signature should be something like `my_beautiful_loss(logits, labels)`
embedding_validation:
Run the validation using embeddings? [default: False]
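As a concrete illustration of the two patterns described above, a conforming architecture/loss pair could look like the sketch below; my_small_net, the layer sizes, and the model_dir are hypothetical, not part of the package:

    import tensorflow as tf
    from bob.learn.tensorflow.loss import mean_cross_entropy_loss
    from bob.learn.tensorflow.trainers import LogitsTrainer

    def my_small_net(placeholder):
        # Must return (graph, end_points), as the trainer expects.
        end_points = dict()
        graph = tf.layers.conv2d(placeholder, filters=16, kernel_size=3,
                                 activation=tf.nn.relu)
        end_points['conv1'] = graph
        graph = tf.layers.dense(tf.layers.flatten(graph), 64,
                                activation=tf.nn.relu)
        end_points['fc1'] = graph
        return graph, end_points

    trainer = LogitsTrainer(architecture=my_small_net,
                            optimizer=tf.train.GradientDescentOptimizer(0.01),
                            n_classes=10,
                            loss_op=mean_cross_entropy_loss,
                            model_dir="/tmp/my_model")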
@@ -81,7 +98,7 @@ class LogitsTrainer(estimator.Estimator):
if self.loss_op is None:
raise ValueError("Please specify a function to build the loss !!")
-        if self.n_classes <=0:
+        if self.n_classes <= 0:
raise ValueError("Number of classes must be greated than 0")
def _model_fn(features, labels, mode, params, config):
@@ -109,7 +126,7 @@ def _model_fn(features, labels, mode, params, config):
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
-# Calculate Loss (for both TRAIN and EVAL modes)
+# Compute Loss (for both TRAIN and EVAL modes)
self.loss = self.loss_op(logits, labels)
# Configure the Training Op (for TRAIN mode)
@@ -136,3 +153,134 @@
model_dir=model_dir,
config=config)
class LogitsCenterLossTrainer(estimator.Estimator):
"""
NN trainer with logits as the last layer.
The **architecture** function should follow this pattern:
def my_beautiful_function(placeholder):
end_points = dict()
graph = convXX(placeholder)
end_points['conv'] = graph
....
return graph, end_points
**Parameters**
architecture:
Pointer to a function that builds the graph.
optimizer:
One of the tensorflow solvers (https://www.tensorflow.org/api_guides/python/train)
- tf.train.GradientDescentOptimizer
- tf.train.AdagradOptimizer
- ....
config:
n_classes:
Number of classes of your problem. The logits will be appended by this class.
loss_op:
Pointer to a function that computes the loss.
embedding_validation:
Run the validation using embeddings? [default: False]
model_dir:
Model path
validation_batch_size:
Size of the validation batch. This value is only used when validating
with embeddings. This is a hack.
"""
def __init__(self,
architecture=None,
optimizer=None,
config=None,
n_classes=0,
embedding_validation=False,
model_dir="",
alpha=0.9,
factor=0.01,
validation_batch_size=None,
):
self.architecture = architecture
self.optimizer = optimizer
self.n_classes = n_classes
self.alpha = alpha
self.factor = factor
self.loss = None
self.embedding_validation = embedding_validation
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture !!")
if self.optimizer is None:
raise ValueError("Please specify a optimizer (https://www.tensorflow.org/api_guides/python/train) !!")
if self.n_classes <= 0:
raise ValueError("Number of classes must be greated than 0")
def _model_fn(features, labels, mode, params, config):
# Building one graph
prelogits = self.architecture(features)[0]
logits = append_logits(prelogits, n_classes)
if self.embedding_validation:
# Compute the embeddings
embeddings = tf.nn.l2_normalize(prelogits, 1)
predictions = {
"embeddings": embeddings
}
else:
predictions = {
# Generate predictions (for PREDICT and EVAL mode)
"classes": tf.argmax(input=logits, axis=1),
# Add `softmax_tensor` to the graph. It is used for PREDICT and by the
# `logging_hook`.
"probabilities": tf.nn.softmax(logits, name="softmax_tensor")
}
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
# Compute Loss (for both TRAIN and EVAL modes)
loss_dict = mean_cross_entropy_center_loss(logits, prelogits, labels, self.n_classes,
alpha=self.alpha, factor=self.factor)
self.loss = loss_dict['loss']
centers = loss_dict['centers']
# Configure the Training Op (for TRAIN mode)
if mode == tf.estimator.ModeKeys.TRAIN:
global_step = tf.contrib.framework.get_or_create_global_step()
# backprop and updating the centers
train_op = tf.group(self.optimizer.minimize(self.loss, global_step=global_step),
centers)
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss,
train_op=train_op)
if self.embedding_validation:
predictions_op = predict_using_tensors(predictions["embeddings"], labels, num=validation_batch_size)
eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions_op)}
return tf.estimator.EstimatorSpec(mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops)
else:
# Add evaluation metrics (for EVAL mode)
eval_metric_ops = {
"accuracy": tf.metrics.accuracy(
labels=labels, predictions=predictions["classes"])}
return tf.estimator.EstimatorSpec(
mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops)
super(LogitsCenterLossTrainer, self).__init__(model_fn=_model_fn,
model_dir=model_dir,
config=config)
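One detail of the TRAIN branch worth spelling out: the centers variable is created with trainable=False inside mean_cross_entropy_center_loss, so the optimizer's gradient step never touches it. Its scatter_sub update op, returned in the loss dict under 'centers', is therefore fused with the gradient step via tf.group so that both run on every training step. In sketch form:

    # Pattern used in _model_fn above: a single train_op that both
    # backpropagates and applies the non-gradient center update.
    train_op = tf.group(
        optimizer.minimize(loss, global_step=global_step),  # trainable variables
        centers)                                            # tf.scatter_sub update op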
@@ -3,7 +3,7 @@ from .Trainer import Trainer
from .SiameseTrainer import SiameseTrainer
from .TripletTrainer import TripletTrainer
from .learning_rate import exponential_decay, constant
-from .LogitsTrainer import LogitsTrainer
+from .LogitsTrainer import LogitsTrainer, LogitsCenterLossTrainer
import numpy
......
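With the export above in place, the new trainer is importable next to the existing one, exactly as the updated test does:

    from bob.learn.tensorflow.trainers import LogitsTrainer, LogitsCenterLossTrainer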