From 35180690b232bcb116db9ff949a70ce391b1e3ca Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Thu, 5 Oct 2017 15:18:13 +0200
Subject: [PATCH] Implemented center loss

---
 bob/learn/tensorflow/loss/BaseLoss.py             | 54 ++++++++-----------
 .../tensorflow/trainers/SiameseTrainer.py         |  1 -
 bob/learn/tensorflow/trainers/Trainer.py          | 49 ++++++++++-------
 3 files changed, 51 insertions(+), 53 deletions(-)

diff --git a/bob/learn/tensorflow/loss/BaseLoss.py b/bob/learn/tensorflow/loss/BaseLoss.py
index e71f0b73..8c27710d 100644
--- a/bob/learn/tensorflow/loss/BaseLoss.py
+++ b/bob/learn/tensorflow/loss/BaseLoss.py
@@ -61,7 +61,7 @@ class MeanSoftMaxLossCenterLoss(object):
     Mean softmax loss. Basically it wrapps the function tf.nn.sparse_softmax_cross_entropy_with_logits.
     """
 
-    def __init__(self, name="loss", add_regularization_losses=True, alpha=0.9, factor=0.01, n_classes=10):
+    def __init__(self, name="loss", alpha=0.9, factor=0.01, n_classes=10):
         """
         Constructor
 
@@ -73,46 +73,36 @@ class MeanSoftMaxLossCenterLoss(object):
 
         """
 
         self.name = name
-        self.add_regularization_losses = add_regularization_losses
         self.n_classes = n_classes
         self.alpha = alpha
         self.factor = factor
 
-    def append_center_loss(self, features, label):
-        nrof_features = features.get_shape()[1]
-
-        centers = tf.get_variable('centers', [self.n_classes, nrof_features], dtype=tf.float32,
-                                  initializer=tf.constant_initializer(0), trainable=False)
-
-        label = tf.reshape(label, [-1])
-        centers_batch = tf.gather(centers, label)
-        diff = (1 - self.alpha) * (centers_batch - features)
-        centers = tf.scatter_sub(centers, label, diff)
-        loss = tf.reduce_mean(tf.square(features - centers_batch))
-
-        return loss
-
-
-    def __call__(self, logits_prelogits, label):
-
-        #TODO: Test the dictionary
-
-        logits = logits_prelogits['logits']
-
+    def __call__(self, logits, prelogits, label):
         # Cross entropy
-        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
-            logits=logits, labels=label), name=self.name)
+        with tf.variable_scope('cross_entropy_loss'):
+            loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
+                logits=logits, labels=label), name=self.name)
 
-        # Appending center loss
-        prelogits = logits_prelogits['prelogits']
-        center_loss = self.append_center_loss(prelogits, label)
-        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, center_loss * self.factor)
+        # Appending center loss
+        with tf.variable_scope('center_loss'):
+            n_features = prelogits.get_shape()[1]
+
+            centers = tf.get_variable('centers', [self.n_classes, n_features], dtype=tf.float32,
+                                      initializer=tf.constant_initializer(0), trainable=False)
+
+            label = tf.reshape(label, [-1])
+            centers_batch = tf.gather(centers, label)
+            diff = (1 - self.alpha) * (centers_batch - prelogits)
+            centers = tf.scatter_sub(centers, label, diff)
+            center_loss = tf.reduce_mean(tf.square(prelogits - centers_batch))
+            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, center_loss * self.factor)
 
         # Adding the regularizers in the loss
-        if self.add_regularization_losses:
+        with tf.variable_scope('total_loss'):
             regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
-            loss = tf.add_n([loss] + regularization_losses, name='total_loss')
+            total_loss = tf.add_n([loss] + regularization_losses, name='total_loss')
 
-        return loss
+        return total_loss, centers
+
 
diff --git a/bob/learn/tensorflow/trainers/SiameseTrainer.py b/bob/learn/tensorflow/trainers/SiameseTrainer.py
index be16b4fe..300b8a06 100644
--- a/bob/learn/tensorflow/trainers/SiameseTrainer.py
+++ b/bob/learn/tensorflow/trainers/SiameseTrainer.py
@@ -219,7 +219,6 @@ class SiameseTrainer(Trainer):
         return feed_dict
 
     def fit(self, step):
-
         feed_dict = self.get_feed_dict(self.train_data_shuffler)
         _, l, bt_class, wt_class, lr, summary = self.session.run([
             self.optimizer,
diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py
index 6631ca7b..8b7ebc2e 100644
--- a/bob/learn/tensorflow/trainers/Trainer.py
+++ b/bob/learn/tensorflow/trainers/Trainer.py
@@ -177,7 +177,7 @@ class Trainer(object):
                     self.compute_validation(step)
 
                 # Taking snapshot
-                if step % self.snapshot == 0:
+                if step % self.snapshot == 0:
                     logger.info("Taking snapshot")
                     path = os.path.join(self.temp_dir, 'model_snapshot{0}.ckp'.format(step))
                     self.saver.save(self.session, path, global_step=step)
@@ -214,6 +214,7 @@ class Trainer(object):
 
                                # Learning rate
                                learning_rate=None,
+                               prelogits=None
                                ):
         """
 
@@ -229,7 +230,6 @@ class Trainer(object):
 
           learning_rate: Learning rate
         """
-
        # Getting the pointer to the placeholders
        self.data_ph = self.train_data_shuffler("data", from_queue=True)
        self.label_ph = self.train_data_shuffler("label", from_queue=True)
@@ -237,8 +237,13 @@ class Trainer(object):
        self.graph = graph
        self.loss = loss
 
-        # Attaching the loss in the graph
-        self.predictor = self.loss(self.graph, self.label_ph)
+        # TODO: SPECIFIC HACK FOR THE CENTER LOSS. I NEED TO FIND A CLEAN SOLUTION FOR THAT
+        self.centers = None
+        if prelogits is not None:
+            tf.add_to_collection("prelogits", prelogits)
+            self.predictor, self.centers = self.loss(self.graph, prelogits, self.label_ph)
+        else:
+            self.predictor = self.loss(self.graph, self.label_ph)
 
        self.optimizer_class = optimizer
        self.learning_rate = learning_rate
@@ -257,11 +262,8 @@ class Trainer(object):
 
        # SAving some variables
        tf.add_to_collection("global_step", self.global_step)
-        if isinstance(self.graph, dict):
-            tf.add_to_collection("graph", self.graph['logits'])
-            tf.add_to_collection("prelogits", self.graph['prelogits'])
-        else:
-            tf.add_to_collection("graph", self.graph)
+
+        tf.add_to_collection("graph", self.graph)
 
        tf.add_to_collection("predictor", self.predictor)
 
@@ -273,6 +275,10 @@ class Trainer(object):
 
        tf.add_to_collection("summaries_train", self.summaries_train)
 
+        # Appending histograms for each trainable variables
+        for var in tf.trainable_variables():
+            tf.summary.histogram(var.op.name, var)
+
        # Same business with the validation
        if self.validation_data_shuffler is not None:
            self.validation_data_ph = self.validation_data_shuffler("data", from_queue=True)
@@ -280,9 +286,9 @@ class Trainer(object):
 
            self.validation_graph = validation_graph
 
-            if self.validate_with_embeddings:
+            if self.validate_with_embeddings:
                self.validation_predictor = self.validation_graph
-            else:
+            else:
                self.validation_predictor = self.loss(self.validation_graph, self.validation_label_ph)
 
            self.summaries_validation = self.create_general_summary(self.validation_predictor, self.validation_graph, self.validation_label_ph)
@@ -318,13 +324,13 @@ class Trainer(object):
        self.saver = tf.train.import_meta_graph(file_name, clear_devices=clear_devices)
        self.saver.restore(self.session, tf.train.latest_checkpoint(os.path.dirname(file_name)))
 
-    def load_variables_from_external_model(self, file_name, var_list):
+    def load_variables_from_external_model(self, checkpoint_path, var_list):
        """
        Load a set of variables from a given model and update them in the current one
 
        ** Parameters **
-          file_name:
+          checkpoint_path:
            Name of the tensorflow model to be loaded
 
          var_list:
            List of variables to be loaded. A tensorflow exception will be raised in case the variable does not exists
@@ -338,7 +344,7 @@ class Trainer(object):
                tf_varlist += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=v)
 
        saver = tf.train.Saver(tf_varlist)
-        saver.restore(self.session, file_name)
+        saver.restore(self.session, tf.train.latest_checkpoint(checkpoint_path))
 
    def create_network_from_file(self, file_name, clear_devices=True):
        """
@@ -406,8 +412,14 @@ class Trainer(object):
        """
        if self.train_data_shuffler.prefetch:
-            _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
-                                                  self.learning_rate, self.summaries_train])
+            # TODO: SPECIFIC HACK FOR THE CENTER LOSS. I NEED TO FIND A CLEAN SOLUTION FOR THAT
+            if self.centers is None:
+                _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
+                                                      self.learning_rate, self.summaries_train])
+            else:
+                _, l, lr, summary, _ = self.session.run([self.optimizer, self.predictor,
+                                                         self.learning_rate, self.summaries_train, self.centers])
+
        else:
            feed_dict = self.get_feed_dict(self.train_data_shuffler)
            _, l, lr, summary = self.session.run([self.optimizer, self.predictor,
@@ -473,10 +485,7 @@ class Trainer(object):
        tf.summary.scalar('lr', self.learning_rate)
 
        # Computing accuracy
-        if isinstance(output, dict):
-            correct_prediction = tf.equal(tf.argmax(output['logits'], 1), label)
-        else:
-            correct_prediction = tf.equal(tf.argmax(output, 1), label)
+        correct_prediction = tf.equal(tf.argmax(output, 1), label)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', accuracy)
-- 
GitLab