Implemented center loss

parent a3e44720
@@ -61,7 +61,7 @@ class MeanSoftMaxLossCenterLoss(object):
     Mean softmax loss. Basically it wraps the function tf.nn.sparse_softmax_cross_entropy_with_logits.
     """

-    def __init__(self, name="loss", add_regularization_losses=True, alpha=0.9, factor=0.01, n_classes=10):
+    def __init__(self, name="loss", alpha=0.9, factor=0.01, n_classes=10):
         """
         Constructor
@@ -73,46 +73,36 @@ class MeanSoftMaxLossCenterLoss(object):
         """
         self.name = name
-        self.add_regularization_losses = add_regularization_losses
         self.n_classes = n_classes
         self.alpha = alpha
         self.factor = factor

-    def append_center_loss(self, features, label):
-        nrof_features = features.get_shape()[1]
-        centers = tf.get_variable('centers', [self.n_classes, nrof_features], dtype=tf.float32,
-                                  initializer=tf.constant_initializer(0), trainable=False)
-        label = tf.reshape(label, [-1])
-        centers_batch = tf.gather(centers, label)
-        diff = (1 - self.alpha) * (centers_batch - features)
-        centers = tf.scatter_sub(centers, label, diff)
-        loss = tf.reduce_mean(tf.square(features - centers_batch))
-        return loss
-
-    def __call__(self, logits_prelogits, label):
-        #TODO: Test the dictionary
-        logits = logits_prelogits['logits']
+    def __call__(self, logits, prelogits, label):

         # Cross entropy
-        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
-            logits=logits, labels=label), name=self.name)
+        with tf.variable_scope('cross_entropy_loss'):
+            loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
+                logits=logits, labels=label), name=self.name)

-        # Appending center loss
-        prelogits = logits_prelogits['prelogits']
-        center_loss = self.append_center_loss(prelogits, label)
-        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, center_loss * self.factor)
+        # Appending center loss
+        with tf.variable_scope('center_loss'):
+            n_features = prelogits.get_shape()[1]
+            centers = tf.get_variable('centers', [self.n_classes, n_features], dtype=tf.float32,
+                                      initializer=tf.constant_initializer(0), trainable=False)
+            label = tf.reshape(label, [-1])
+            centers_batch = tf.gather(centers, label)
+            diff = (1 - self.alpha) * (centers_batch - prelogits)
+            centers = tf.scatter_sub(centers, label, diff)
+            center_loss = tf.reduce_mean(tf.square(prelogits - centers_batch))
+            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, center_loss * self.factor)

         # Adding the regularizers in the loss
-        if self.add_regularization_losses:
+        with tf.variable_scope('total_loss'):
             regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
-            loss = tf.add_n([loss] + regularization_losses, name='total_loss')
+            total_loss = tf.add_n([loss] + regularization_losses, name='total_loss')

-        return loss
+        return total_loss, centers
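
For reviewers new to this pattern: tf.scatter_sub only mutates the centers variable when the op it returns is fetched (directly or through a dependency), which is why __call__ now hands the centers tensor back next to the total loss. Below is a minimal standalone sketch of the same update rule; the shapes, names, and throwaway session are illustrative and not part of the patch:

    import tensorflow as tf

    n_classes, n_features, alpha = 10, 128, 0.9

    prelogits = tf.placeholder(tf.float32, [None, n_features], name='prelogits')
    label = tf.placeholder(tf.int64, [None], name='label')

    centers = tf.get_variable('centers', [n_classes, n_features], dtype=tf.float32,
                              initializer=tf.constant_initializer(0), trainable=False)

    centers_batch = tf.gather(centers, label)              # current center of each sample's class
    diff = (1 - alpha) * (centers_batch - prelogits)       # moving-average step towards the batch
    update_centers = tf.scatter_sub(centers, label, diff)  # mutates 'centers' only when fetched
    center_loss = tf.reduce_mean(tf.square(prelogits - centers_batch))

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        # Fetching update_centers is what actually moves the centers
        session.run([center_loss, update_centers],
                    feed_dict={prelogits: [[0.0] * n_features], label: [3]})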
@@ -219,7 +219,6 @@ class SiameseTrainer(Trainer):
         return feed_dict

     def fit(self, step):
-        feed_dict = self.get_feed_dict(self.train_data_shuffler)
         _, l, bt_class, wt_class, lr, summary = self.session.run([
             self.optimizer,
@@ -177,7 +177,7 @@ class Trainer(object):
                 self.compute_validation(step)

             # Taking snapshot
-            if step % self.snapshot == 0:
+            if step % self.snapshot == 0:
                 logger.info("Taking snapshot")
                 path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                 self.saver.save(self.session, path, global_step=step)
@@ -214,6 +214,7 @@ class Trainer(object):
                   # Learning rate
                   learning_rate=None,
+                  prelogits=None
                   ):
         """
@@ -229,7 +230,6 @@ class Trainer(object):
           learning_rate: Learning rate

         """
-
         # Getting the pointer to the placeholders
         self.data_ph = self.train_data_shuffler("data", from_queue=True)
         self.label_ph = self.train_data_shuffler("label", from_queue=True)
@@ -237,8 +237,13 @@ class Trainer(object):
         self.graph = graph
         self.loss = loss

         # Attaching the loss in the graph
-        self.predictor = self.loss(self.graph, self.label_ph)
+        # TODO: SPECIFIC HACK FOR THE CENTER LOSS. I NEED TO FIND A CLEAN SOLUTION FOR THAT
+        self.centers = None
+        if prelogits is not None:
+            tf.add_to_collection("prelogits", prelogits)
+            self.predictor, self.centers = self.loss(self.graph, prelogits, self.label_ph)
+        else:
+            self.predictor = self.loss(self.graph, self.label_ph)

         self.optimizer_class = optimizer
         self.learning_rate = learning_rate
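
A side note on tf.add_to_collection("prelogits", prelogits): collections survive meta-graph export, so the tensor can be recovered by name after a saved model is reloaded, without keeping a Python reference around. A sketch of the lookup on the loading side (the snapshot file and directory names are hypothetical):

    import tensorflow as tf

    saver = tf.train.import_meta_graph('model_snapshot1000.ckp.meta')  # hypothetical snapshot name
    with tf.Session() as session:
        saver.restore(session, tf.train.latest_checkpoint('./temp_dir'))  # hypothetical directory
        # The same tensor that was registered during graph construction:
        prelogits = tf.get_collection("prelogits")[0]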
@@ -257,11 +262,8 @@ class Trainer(object):
         # Saving some variables
         tf.add_to_collection("global_step", self.global_step)

-        if isinstance(self.graph, dict):
-            tf.add_to_collection("graph", self.graph['logits'])
-            tf.add_to_collection("prelogits", self.graph['prelogits'])
-        else:
-            tf.add_to_collection("graph", self.graph)
+        tf.add_to_collection("graph", self.graph)

         tf.add_to_collection("predictor", self.predictor)
@@ -273,6 +275,10 @@ class Trainer(object):
         tf.add_to_collection("summaries_train", self.summaries_train)

+        # Appending histograms for each trainable variable
+        for var in tf.trainable_variables():
+            tf.summary.histogram(var.op.name, var)
+
         # Same business with the validation
         if self.validation_data_shuffler is not None:
             self.validation_data_ph = self.validation_data_shuffler("data", from_queue=True)
@@ -280,9 +286,9 @@ class Trainer(object):
             self.validation_graph = validation_graph

-            if self.validate_with_embeddings:
+            if self.validate_with_embeddings:
                 self.validation_predictor = self.validation_graph
-            else:
+            else:
                 self.validation_predictor = self.loss(self.validation_graph, self.validation_label_ph)

             self.summaries_validation = self.create_general_summary(self.validation_predictor, self.validation_graph, self.validation_label_ph)
@@ -318,13 +324,13 @@ class Trainer(object):
         self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
         self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))

-    def load_variables_from_external_model(self, file_name, var_list):
+    def load_variables_from_external_model(self, checkpoint_path, var_list):
         """
         Load a set of variables from a given model and update them in the current one

         ** Parameters **

-          file_name:
+          checkpoint_path:
             Name of the tensorflow model to be loaded

           var_list:
             List of variables to be loaded. A tensorflow exception will be raised in case the variable does not exist
@@ -338,7 +344,7 @@ class Trainer(object):
             tf_varlist += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=v)

         saver = tf.train.Saver(tf_varlist)
-        saver.restore(self.session, file_name)
+        saver.restore(self.session, tf.train.latest_checkpoint(checkpoint_path))

     def create_network_from_file(self, file_name, clear_devices=True):
         """
@@ -406,8 +412,14 @@ class Trainer(object):
         """
         if self.train_data_shuffler.prefetch:
-            _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
-                                                  self.learning_rate, self.summaries_train])
+            # TODO: SPECIFIC HACK FOR THE CENTER LOSS. I NEED TO FIND A CLEAN SOLUTION FOR THAT
+            if self.centers is None:
+                _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
+                                                      self.learning_rate, self.summaries_train])
+            else:
+                _, l, lr, summary, _ = self.session.run([self.optimizer, self.predictor,
+                                                         self.learning_rate, self.summaries_train, self.centers])
         else:
             feed_dict = self.get_feed_dict(self.train_data_shuffler)
             _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
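
Regarding the TODO above: one cleaner alternative, an assumption on my part rather than anything in this patch, is to attach the centers update to the train op through a control dependency, so a single fetch drives both:

    import tensorflow as tf

    # 'total_loss' and 'centers' are the two values returned by
    # MeanSoftMaxLossCenterLoss.__call__; 'optimizer' and 'global_step'
    # stand in for the trainer's usual attributes.
    with tf.control_dependencies([centers]):
        train_op = optimizer.minimize(total_loss, global_step=global_step)

    # session.run(train_op) now updates the centers as a side effect,
    # removing the need for the special-cased fetch list in fit().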
@@ -473,10 +485,7 @@ class Trainer(object):
         tf.summary.scalar('lr', self.learning_rate)

         # Computing accuracy
-        if isinstance(output, dict):
-            correct_prediction = tf.equal(tf.argmax(output['logits'], 1), label)
-        else:
-            correct_prediction = tf.equal(tf.argmax(output, 1), label)
+        correct_prediction = tf.equal(tf.argmax(output, 1), label)
         accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
         tf.summary.scalar('accuracy', accuracy)