Commit 35180690 authored by Tiago de Freitas Pereira

Implemented center loss

parent a3e44720
1 merge request: !17 Updates
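For context, this commit implements the center-loss technique of Wen et al. (ECCV 2016): alongside the softmax cross-entropy it penalizes the distance between each sample's prelogits (deep features) and a running center for its class, and it nudges the centers toward the batch features. Reading the update rule off the code below, with `factor` playing the role of the loss weight λ and `alpha` controlling how fast the centers move:

```latex
L_{\text{total}} = L_{\text{softmax}}
                 + \lambda \cdot \frac{1}{m} \sum_{i=1}^{m} \lVert x_i - c_{y_i} \rVert_2^2,
\qquad
c_{y_i} \leftarrow c_{y_i} - (1 - \alpha)\,(c_{y_i} - x_i)
```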
@@ -61,7 +61,7 @@ class MeanSoftMaxLossCenterLoss(object):
     Mean softmax loss. Basically it wraps the function tf.nn.sparse_softmax_cross_entropy_with_logits.
     """

-    def __init__(self, name="loss", add_regularization_losses=True, alpha=0.9, factor=0.01, n_classes=10):
+    def __init__(self, name="loss", alpha=0.9, factor=0.01, n_classes=10):
         """
         Constructor
@@ -73,46 +73,36 @@
         """
         self.name = name
-        self.add_regularization_losses = add_regularization_losses
         self.n_classes = n_classes
         self.alpha = alpha
         self.factor = factor

-    def append_center_loss(self, features, label):
-        nrof_features = features.get_shape()[1]
-        centers = tf.get_variable('centers', [self.n_classes, nrof_features], dtype=tf.float32,
-                                  initializer=tf.constant_initializer(0), trainable=False)
-        label = tf.reshape(label, [-1])
-        centers_batch = tf.gather(centers, label)
-        diff = (1 - self.alpha) * (centers_batch - features)
-        centers = tf.scatter_sub(centers, label, diff)
-        loss = tf.reduce_mean(tf.square(features - centers_batch))
-        return loss
-
-    def __call__(self, logits_prelogits, label):
-        # TODO: Test the dictionary
-        logits = logits_prelogits['logits']
+    def __call__(self, logits, prelogits, label):

         # Cross entropy
-        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
-            logits=logits, labels=label), name=self.name)
+        with tf.variable_scope('cross_entropy_loss'):
+            loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
+                logits=logits, labels=label), name=self.name)

-        # Appending center loss
-        prelogits = logits_prelogits['prelogits']
-        center_loss = self.append_center_loss(prelogits, label)
-        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, center_loss * self.factor)
+        # Appending center loss
+        with tf.variable_scope('center_loss'):
+            n_features = prelogits.get_shape()[1]
+            centers = tf.get_variable('centers', [self.n_classes, n_features], dtype=tf.float32,
+                                      initializer=tf.constant_initializer(0), trainable=False)
+            label = tf.reshape(label, [-1])
+            centers_batch = tf.gather(centers, label)
+            diff = (1 - self.alpha) * (centers_batch - prelogits)
+            centers = tf.scatter_sub(centers, label, diff)
+            center_loss = tf.reduce_mean(tf.square(prelogits - centers_batch))
+            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, center_loss * self.factor)

         # Adding the regularizers in the loss
-        if self.add_regularization_losses:
+        with tf.variable_scope('total_loss'):
             regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
-            loss = tf.add_n([loss] + regularization_losses, name='total_loss')
+            total_loss = tf.add_n([loss] + regularization_losses, name='total_loss')

-        return loss
+        return total_loss, centers
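A minimal usage sketch of the rewritten loss (the placeholder tensors and shapes are illustrative assumptions, not part of this commit): the call now takes logits and prelogits separately, and returns both the total loss and the centers update op.

```python
import tensorflow as tf  # TF1-style API, matching the code above

# Hypothetical stand-ins for a real network's outputs:
prelogits = tf.placeholder(tf.float32, [None, 128])  # embedding layer
logits = tf.placeholder(tf.float32, [None, 10])      # class scores
labels = tf.placeholder(tf.int64, [None])

loss_fn = MeanSoftMaxLossCenterLoss(alpha=0.9, factor=0.01, n_classes=10)
total_loss, centers = loss_fn(logits, prelogits, labels)

# `centers` is the tf.scatter_sub result: the class centers only move when
# it is fetched, which is why the trainer below runs it with the optimizer.
```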
@@ -219,7 +219,6 @@ class SiameseTrainer(Trainer):
         return feed_dict

     def fit(self, step):
-        feed_dict = self.get_feed_dict(self.train_data_shuffler)
         _, l, bt_class, wt_class, lr, summary = self.session.run([
             self.optimizer,
......
@@ -177,7 +177,7 @@ class Trainer(object):
                 self.compute_validation(step)

             # Taking snapshot
-            if step % self.snapshot == 0:
+            if step % self.snapshot == 0:
                 logger.info("Taking snapshot")
                 path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                 self.saver.save(self.session, path, global_step=step)
@@ -214,6 +214,7 @@
                                     # Learning rate
                                     learning_rate=None,
+                                    prelogits=None
                                     ):
         """
@@ -229,7 +230,6 @@
           learning_rate: Learning rate

         """
-
         # Getting the pointer to the placeholders
         self.data_ph = self.train_data_shuffler("data", from_queue=True)
         self.label_ph = self.train_data_shuffler("label", from_queue=True)
@@ -237,8 +237,13 @@
         self.graph = graph
         self.loss = loss

         # Attaching the loss in the graph
-        self.predictor = self.loss(self.graph, self.label_ph)
+        # TODO: SPECIFIC HACK FOR THE CENTER LOSS. I NEED TO FIND A CLEAN SOLUTION FOR THAT
+        self.centers = None
+        if prelogits is not None:
+            tf.add_to_collection("prelogits", prelogits)
+            self.predictor, self.centers = self.loss(self.graph, prelogits, self.label_ph)
+        else:
+            self.predictor = self.loss(self.graph, self.label_ph)

         self.optimizer_class = optimizer
         self.learning_rate = learning_rate
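Because prelogits now goes into a named collection, a graph restored from a meta file can recover the relevant tensors by collection name. A small sketch, assuming the meta-graph was saved by this trainer:

```python
# After tf.train.import_meta_graph(...) and saver.restore(...):
logits = tf.get_collection("graph")[0]
prelogits = tf.get_collection("prelogits")[0]
```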
@@ -257,11 +262,8 @@
         # Saving some variables
         tf.add_to_collection("global_step", self.global_step)

-        if isinstance(self.graph, dict):
-            tf.add_to_collection("graph", self.graph['logits'])
-            tf.add_to_collection("prelogits", self.graph['prelogits'])
-        else:
-            tf.add_to_collection("graph", self.graph)
+        tf.add_to_collection("graph", self.graph)

         tf.add_to_collection("predictor", self.predictor)
@@ -273,6 +275,10 @@
         tf.add_to_collection("summaries_train", self.summaries_train)

+        # Appending histograms for each trainable variable
+        for var in tf.trainable_variables():
+            tf.summary.histogram(var.op.name, var)

         # Same business with the validation
         if self.validation_data_shuffler is not None:
             self.validation_data_ph = self.validation_data_shuffler("data", from_queue=True)
@@ -280,9 +286,9 @@
             self.validation_graph = validation_graph

-            if self.validate_with_embeddings:
+            if self.validate_with_embeddings:
                 self.validation_predictor = self.validation_graph
-            else:
+            else:
                 self.validation_predictor = self.loss(self.validation_graph, self.validation_label_ph)

             self.summaries_validation = self.create_general_summary(self.validation_predictor, self.validation_graph, self.validation_label_ph)
@@ -318,13 +324,13 @@
         self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
         self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))

-    def load_variables_from_external_model(self, file_name, var_list):
+    def load_variables_from_external_model(self, checkpoint_path, var_list):
         """
         Load a set of variables from a given model and update them in the current one

         ** Parameters **

-          file_name:
+          checkpoint_path:
             Name of the tensorflow model to be loaded

           var_list:
             List of variables to be loaded. A tensorflow exception will be raised in case the variable does not exist
@@ -338,7 +344,7 @@
             tf_varlist += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=v)

         saver = tf.train.Saver(tf_varlist)
-        saver.restore(self.session, file_name)
+        saver.restore(self.session, tf.train.latest_checkpoint(checkpoint_path))

     def create_network_from_file(self, file_name, clear_devices=True):
         """
@@ -406,8 +412,14 @@
         """
         if self.train_data_shuffler.prefetch:
-            _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
-                                                  self.learning_rate, self.summaries_train])
+            # TODO: SPECIFIC HACK FOR THE CENTER LOSS. I NEED TO FIND A CLEAN SOLUTION FOR THAT
+            if self.centers is None:
+                _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
+                                                      self.learning_rate, self.summaries_train])
+            else:
+                _, l, lr, summary, _ = self.session.run([self.optimizer, self.predictor,
+                                                         self.learning_rate, self.summaries_train, self.centers])
         else:
             feed_dict = self.get_feed_dict(self.train_data_shuffler)
             _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
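As the TODO admits, fetching `self.centers` explicitly each step is a stopgap. One common cleaner pattern, offered here as an assumption rather than what this commit does, is to hang the centers update on the train op with a control dependency, so a single fetch updates both:

```python
# Sketch: make the optimizer step depend on the centers update, so
# session.run(train_op) also applies tf.scatter_sub to the centers.
with tf.control_dependencies([centers]):
    train_op = optimizer.minimize(total_loss, global_step=global_step)
```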
@@ -473,10 +485,7 @@
         tf.summary.scalar('lr', self.learning_rate)

         # Computing accuracy
-        if isinstance(output, dict):
-            correct_prediction = tf.equal(tf.argmax(output['logits'], 1), label)
-        else:
-            correct_prediction = tf.equal(tf.argmax(output, 1), label)
+        correct_prediction = tf.equal(tf.argmax(output, 1), label)
         accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
         tf.summary.scalar('accuracy', accuracy)
......