From ad6a9bba6c6243c7d5a88e0574a65b458ff06717 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Wed, 11 Oct 2017 17:58:04 +0200 Subject: [PATCH] Replaced some classes to functions #37 . Still need to update some tests --- bob/learn/tensorflow/loss/BaseLoss.py | 101 ++++++-------- bob/learn/tensorflow/loss/ContrastiveLoss.py | 45 +++---- .../tensorflow/loss/TripletAverageLoss.py | 60 --------- .../tensorflow/loss/TripletFisherLoss.py | 62 --------- bob/learn/tensorflow/loss/TripletLoss.py | 122 ++++++++++++++--- bob/learn/tensorflow/loss/__init__.py | 20 +-- bob/learn/tensorflow/network/Chopra.py | 123 +++++++----------- bob/learn/tensorflow/network/__init__.py | 2 +- bob/learn/tensorflow/network/utils.py | 11 +- .../tensorflow/script/lfw_db_to_tfrecords.py | 10 +- bob/learn/tensorflow/script/train.py | 8 +- bob/learn/tensorflow/test/test_cnn.py | 79 ++++++----- .../tensorflow/test/test_cnn_other_losses.py | 50 ++----- bob/learn/tensorflow/test/test_cnn_scratch.py | 2 +- .../tensorflow/trainers/SiameseTrainer.py | 33 ++--- bob/learn/tensorflow/trainers/Trainer.py | 26 ++-- .../tensorflow/trainers/TripletTrainer.py | 30 ++--- 17 files changed, 315 insertions(+), 469 deletions(-) delete mode 100755 bob/learn/tensorflow/loss/TripletAverageLoss.py delete mode 100755 bob/learn/tensorflow/loss/TripletFisherLoss.py diff --git a/bob/learn/tensorflow/loss/BaseLoss.py b/bob/learn/tensorflow/loss/BaseLoss.py index 679dd12b..cbfeadfa 100755 --- a/bob/learn/tensorflow/loss/BaseLoss.py +++ b/bob/learn/tensorflow/loss/BaseLoss.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # vim: set fileencoding=utf-8 : # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch> -# @date: Tue 09 Aug 2016 16:38 CEST import logging import tensorflow as tf @@ -10,92 +9,70 @@ logger = logging.getLogger("bob.learn.tensorflow") slim = tf.contrib.slim -class BaseLoss(object): - """ - Base loss function. - Stupid class. Don't know why I did that. - """ - - def __init__(self, loss, operation, name="loss"): - self.loss = loss - self.operation = operation - self.name = name - - def __call__(self, graph, label): - return self.operation(self.loss(logits=graph, labels=label), name=self.name) - - -class MeanSoftMaxLoss(object): +def mean_cross_entropy_loss(logits, labels, add_regularization_losses=True): """ Simple CrossEntropy loss. Basically it wrapps the function tf.nn.sparse_softmax_cross_entropy_with_logits. **Parameters** - - name: Scope name + logits: + labels: add_regularization_losses: Regulize the loss??? 
""" - def __init__(self, name="loss", add_regularization_losses=True): - self.name = name - self.add_regularization_losses = add_regularization_losses + with tf.variable_scope('cross_entropy_loss'): - def __call__(self, graph, label): - loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits( - logits=graph, labels=label), name=self.name) - - if self.add_regularization_losses: + logits=logits, labels=labels), name=tf.GraphKeys.LOSSES) + + if add_regularization_losses: regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) - return tf.add_n([loss] + regularization_losses, name='total_loss') + return tf.add_n([loss] + regularization_losses, name=tf.GraphKeys.LOSSES) else: return loss -class MeanSoftMaxLossCenterLoss(object): +def mean_cross_entropy_center_loss(logits, prelogits, labels, n_classes, alpha=0.9, factor=0.01): """ Implementation of the CrossEntropy + Center Loss from the paper "A Discriminative Feature Learning Approach for Deep Face Recognition"(http://ydwen.github.io/papers/WenECCV16.pdf) **Parameters** - - name: Scope name + logits: + prelogits: + labels: + n_classes: Number of classes of your task alpha: Alpha factor ((1-alpha)*centers-prelogits) factor: Weight factor of the center loss - n_classes: Number of classes of your task + """ - def __init__(self, name="loss", alpha=0.9, factor=0.01, n_classes=10): - self.name = name + # Cross entropy + with tf.variable_scope('cross_entropy_loss'): + loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=labels), name=tf.GraphKeys.LOSSES) - self.n_classes = n_classes - self.alpha = alpha - self.factor = factor + # Appending center loss + with tf.variable_scope('center_loss'): + n_features = prelogits.get_shape()[1] + + centers = tf.get_variable('centers', [n_classes, n_features], dtype=tf.float32, + initializer=tf.constant_initializer(0), trainable=False) + + label = tf.reshape(labels, [-1]) + centers_batch = tf.gather(centers, labels) + diff = (1 - alpha) * (centers_batch - prelogits) + centers = tf.scatter_sub(centers, labels, diff) + center_loss = tf.reduce_mean(tf.square(prelogits - centers_batch)) + tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, center_loss * factor) + # Adding the regularizers in the loss + with tf.variable_scope('total_loss'): + regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) + total_loss = tf.add_n([loss] + regularization_losses, name=tf.GraphKeys.LOSSES) - def __call__(self, logits, prelogits, label): - # Cross entropy - with tf.variable_scope('cross_entropy_loss'): - loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits( - logits=logits, labels=label), name=self.name) + loss = dict() + loss['loss'] = total_loss + loss['centers'] = centers - # Appending center loss - with tf.variable_scope('center_loss'): - n_features = prelogits.get_shape()[1] - - centers = tf.get_variable('centers', [self.n_classes, n_features], dtype=tf.float32, - initializer=tf.constant_initializer(0), trainable=False) - - label = tf.reshape(label, [-1]) - centers_batch = tf.gather(centers, label) - diff = (1 - self.alpha) * (centers_batch - prelogits) - centers = tf.scatter_sub(centers, label, diff) - center_loss = tf.reduce_mean(tf.square(prelogits - centers_batch)) - tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, center_loss * self.factor) - - # Adding the regularizers in the loss - with tf.variable_scope('total_loss'): - regularization_losses = 
tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) - total_loss = tf.add_n([loss] + regularization_losses, name='total_loss') - - return total_loss, centers + return loss diff --git a/bob/learn/tensorflow/loss/ContrastiveLoss.py b/bob/learn/tensorflow/loss/ContrastiveLoss.py index 4c25a981..1ec9ace5 100755 --- a/bob/learn/tensorflow/loss/ContrastiveLoss.py +++ b/bob/learn/tensorflow/loss/ContrastiveLoss.py @@ -1,17 +1,15 @@ #!/usr/bin/env python # vim: set fileencoding=utf-8 : # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch> -# @date: Wed 10 Aug 2016 16:38 CEST import logging logger = logging.getLogger("bob.learn.tensorflow") import tensorflow as tf -from .BaseLoss import BaseLoss from bob.learn.tensorflow.utils import compute_euclidean_distance -class ContrastiveLoss(BaseLoss): +def contrastive_loss(left_embedding, right_embedding, labels, contrastive_margin=1.0): """ Compute the contrastive loss as in @@ -27,7 +25,7 @@ class ContrastiveLoss(BaseLoss): right_feature: Second element of the pair - label: + labels: Label of the pair (0 or 1) margin: @@ -35,30 +33,25 @@ class ContrastiveLoss(BaseLoss): """ - def __init__(self, contrastive_margin=1.0): - self.contrastive_margin = contrastive_margin + with tf.name_scope("contrastive_loss"): + labels = tf.to_float(labels) + + left_embedding = tf.nn.l2_normalize(left_embedding, 1) + right_embedding = tf.nn.l2_normalize(right_embedding, 1) - def __call__(self, label, left_feature, right_feature): - with tf.name_scope("contrastive_loss"): - label = tf.to_float(label) - - left_feature = tf.nn.l2_normalize(left_feature, 1) - right_feature = tf.nn.l2_normalize(right_feature, 1) + one = tf.constant(1.0) - one = tf.constant(1.0) + d = compute_euclidean_distance(left_embedding, right_embedding) + within_class = tf.multiply(one - labels, tf.square(d)) # (1-Y)*(d^2) + + max_part = tf.square(tf.maximum(contrastive_margin - d, 0)) + between_class = tf.multiply(labels, max_part) # (Y) * max((margin - d)^2, 0) - d = compute_euclidean_distance(left_feature, right_feature) - within_class = tf.multiply(one - label, tf.square(d)) # (1-Y)*(d^2) - - - max_part = tf.square(tf.maximum(self.contrastive_margin - d, 0)) - between_class = tf.multiply(label, max_part) # (Y) * max((margin - d)^2, 0) + loss = 0.5 * (within_class + between_class) - loss = 0.5 * (within_class + between_class) + loss_dict = dict() + loss_dict['loss'] = tf.reduce_mean(loss, name=tf.GraphKeys.LOSSES) + loss_dict['between_class'] = tf.reduce_mean(between_class, name=tf.GraphKeys.LOSSES) + loss_dict['within_class'] = tf.reduce_mean(within_class, name=tf.GraphKeys.LOSSES) - loss_dict = dict() - loss_dict['loss'] = tf.reduce_mean(loss) - loss_dict['between_class'] = tf.reduce_mean(between_class) - loss_dict['within_class'] = tf.reduce_mean(within_class) - - return loss_dict + return loss_dict diff --git a/bob/learn/tensorflow/loss/TripletAverageLoss.py b/bob/learn/tensorflow/loss/TripletAverageLoss.py deleted file mode 100755 index bcb7bea8..00000000 --- a/bob/learn/tensorflow/loss/TripletAverageLoss.py +++ /dev/null @@ -1,60 +0,0 @@ -#!/usr/bin/env python -# vim: set fileencoding=utf-8 : -# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch> -# @date: Wed 10 Aug 2016 16:38 CEST - -import logging -logger = logging.getLogger("bob.learn.tensorflow") -import tensorflow as tf - -from .BaseLoss import BaseLoss -from bob.learn.tensorflow.utils import compute_euclidean_distance - - -class TripletAverageLoss(BaseLoss): - """ - Compute the triplet loss as in - - Schroff, Florian, Dmitry 
Kalenichenko, and James Philbin. - "Facenet: A unified embedding for face recognition and clustering." - Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015. - - :math:`L = sum( |f_a - f_p|^2 - |f_a - f_n|^2 + \lambda)` - - **Parameters** - - left_feature: - First element of the pair - - right_feature: - Second element of the pair - - label: - Label of the pair (0 or 1) - - margin: - Contrastive margin - - """ - - def __init__(self, margin=0.1): - self.margin = margin - - def __call__(self, anchor_embedding, positive_embedding, negative_embedding): - - with tf.name_scope("triplet_loss"): - # Normalize - anchor_embedding = tf.nn.l2_normalize(anchor_embedding, 1, 1e-10, name="anchor") - positive_embedding = tf.nn.l2_normalize(positive_embedding, 1, 1e-10, name="positive") - negative_embedding = tf.nn.l2_normalize(negative_embedding, 1, 1e-10, name="negative") - - anchor_mean = tf.reduce_mean(anchor_embedding, 0) - - d_positive = tf.reduce_sum(tf.square(tf.subtract(anchor_mean, positive_embedding)), 1) - d_negative = tf.reduce_sum(tf.square(tf.subtract(anchor_mean, negative_embedding)), 1) - - basic_loss = tf.add(tf.subtract(d_positive, d_negative), self.margin) - loss = tf.reduce_mean(tf.maximum(basic_loss, 0.0), 0) - - return loss, tf.reduce_mean(d_negative), tf.reduce_mean(d_positive) - diff --git a/bob/learn/tensorflow/loss/TripletFisherLoss.py b/bob/learn/tensorflow/loss/TripletFisherLoss.py deleted file mode 100755 index 54c0ad02..00000000 --- a/bob/learn/tensorflow/loss/TripletFisherLoss.py +++ /dev/null @@ -1,62 +0,0 @@ -#!/usr/bin/env python -# vim: set fileencoding=utf-8 : -# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch> -# @date: Wed 10 Aug 2016 16:38 CEST - -import logging -logger = logging.getLogger("bob.learn.tensorflow") -import tensorflow as tf - -from .BaseLoss import BaseLoss -from bob.learn.tensorflow.utils import compute_euclidean_distance - - -class TripletFisherLoss(BaseLoss): - """ - """ - - def __init__(self): - pass - - def __call__(self, anchor_embedding, positive_embedding, negative_embedding): - - with tf.name_scope("triplet_loss"): - # Normalize - anchor_embedding = tf.nn.l2_normalize(anchor_embedding, 1, 1e-10, name="anchor") - positive_embedding = tf.nn.l2_normalize(positive_embedding, 1, 1e-10, name="positive") - negative_embedding = tf.nn.l2_normalize(negative_embedding, 1, 1e-10, name="negative") - - average_class = tf.reduce_mean(anchor_embedding, 0) - average_total = tf.div(tf.add(tf.reduce_mean(anchor_embedding, axis=0),\ - tf.reduce_mean(negative_embedding, axis=0)), 2) - - length = anchor_embedding.get_shape().as_list()[0] - dim = anchor_embedding.get_shape().as_list()[1] - split_positive = tf.unstack(positive_embedding, num=length, axis=0) - split_negative = tf.unstack(negative_embedding, num=length, axis=0) - - Sw = None - Sb = None - for s in zip(split_positive, split_negative): - positive = s[0] - negative = s[1] - - buffer_sw = tf.reshape(tf.subtract(positive, average_class), shape=(dim, 1)) - buffer_sw = tf.matmul(buffer_sw, tf.reshape(buffer_sw, shape=(1, dim))) - - buffer_sb = tf.reshape(tf.subtract(negative, average_total), shape=(dim, 1)) - buffer_sb = tf.matmul(buffer_sb, tf.reshape(buffer_sb, shape=(1, dim))) - - if Sw is None: - Sw = buffer_sw - Sb = buffer_sb - else: - Sw = tf.add(Sw, buffer_sw) - Sb = tf.add(Sb, buffer_sb) - - # Sw = tf.trace(Sw) - # Sb = tf.trace(Sb) - #loss = tf.trace(tf.div(Sb, Sw)) - loss = tf.trace(tf.div(Sw, Sb)) - - return loss, tf.trace(Sb), tf.trace(Sw) diff --git 
a/bob/learn/tensorflow/loss/TripletLoss.py b/bob/learn/tensorflow/loss/TripletLoss.py index 4478a12d..c642507c 100755 --- a/bob/learn/tensorflow/loss/TripletLoss.py +++ b/bob/learn/tensorflow/loss/TripletLoss.py @@ -1,17 +1,15 @@ #!/usr/bin/env python # vim: set fileencoding=utf-8 : # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch> -# @date: Wed 10 Aug 2016 16:38 CEST import logging logger = logging.getLogger("bob.learn.tensorflow") import tensorflow as tf -from .BaseLoss import BaseLoss from bob.learn.tensorflow.utils import compute_euclidean_distance -class TripletLoss(BaseLoss): +def triplet_loss(anchor_embedding, positive_embedding, negative_embedding, margin=5.0): """ Compute the triplet loss as in @@ -37,26 +35,110 @@ class TripletLoss(BaseLoss): """ - def __init__(self, margin=5.0): - self.margin = margin + with tf.name_scope("triplet_loss"): + # Normalize + anchor_embedding = tf.nn.l2_normalize(anchor_embedding, 1, 1e-10, name="anchor") + positive_embedding = tf.nn.l2_normalize(positive_embedding, 1, 1e-10, name="positive") + negative_embedding = tf.nn.l2_normalize(negative_embedding, 1, 1e-10, name="negative") + + d_positive = tf.reduce_sum(tf.square(tf.subtract(anchor_embedding, positive_embedding)), 1) + d_negative = tf.reduce_sum(tf.square(tf.subtract(anchor_embedding, negative_embedding)), 1) + + basic_loss = tf.add(tf.subtract(d_positive, d_negative), margin) + loss = tf.reduce_mean(tf.maximum(basic_loss, 0.0), 0, name=tf.GraphKeys.LOSSES) + + loss_dict = dict() + loss_dict['loss'] = loss + loss_dict['between_class'] = tf.reduce_mean(d_negative) + loss_dict['within_class'] = tf.reduce_mean(d_positive) + + return loss_dict + + +def triplet_fisher_loss(anchor_embedding, positive_embedding, negative_embedding): + + with tf.name_scope("triplet_loss"): + # Normalize + anchor_embedding = tf.nn.l2_normalize(anchor_embedding, 1, 1e-10, name="anchor") + positive_embedding = tf.nn.l2_normalize(positive_embedding, 1, 1e-10, name="positive") + negative_embedding = tf.nn.l2_normalize(negative_embedding, 1, 1e-10, name="negative") + + average_class = tf.reduce_mean(anchor_embedding, 0) + average_total = tf.div(tf.add(tf.reduce_mean(anchor_embedding, axis=0),\ + tf.reduce_mean(negative_embedding, axis=0)), 2) + + length = anchor_embedding.get_shape().as_list()[0] + dim = anchor_embedding.get_shape().as_list()[1] + split_positive = tf.unstack(positive_embedding, num=length, axis=0) + split_negative = tf.unstack(negative_embedding, num=length, axis=0) + + Sw = None + Sb = None + for s in zip(split_positive, split_negative): + positive = s[0] + negative = s[1] + + buffer_sw = tf.reshape(tf.subtract(positive, average_class), shape=(dim, 1)) + buffer_sw = tf.matmul(buffer_sw, tf.reshape(buffer_sw, shape=(1, dim))) + + buffer_sb = tf.reshape(tf.subtract(negative, average_total), shape=(dim, 1)) + buffer_sb = tf.matmul(buffer_sb, tf.reshape(buffer_sb, shape=(1, dim))) + + if Sw is None: + Sw = buffer_sw + Sb = buffer_sb + else: + Sw = tf.add(Sw, buffer_sw) + Sb = tf.add(Sb, buffer_sb) + + # Sw = tf.trace(Sw) + # Sb = tf.trace(Sb) + #loss = tf.trace(tf.div(Sb, Sw)) + loss = tf.trace(tf.div(Sw, Sb), name=tf.GraphKeys.LOSSES) + + return loss, tf.trace(Sb), tf.trace(Sw) + + +def triplet_average_loss(anchor_embedding, positive_embedding, negative_embedding, margin=5.0): + """ + Compute the triplet loss as in + + Schroff, Florian, Dmitry Kalenichenko, and James Philbin. + "Facenet: A unified embedding for face recognition and clustering." 
+ Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015. + + :math:`L = sum( |f_a - f_p|^2 - |f_a - f_n|^2 + \lambda)` + + **Parameters** + + left_feature: + First element of the pair + + right_feature: + Second element of the pair + + label: + Label of the pair (0 or 1) + + margin: + Contrastive margin + + """ - def __call__(self, anchor_embedding, positive_embedding, negative_embedding): + with tf.name_scope("triplet_loss"): + # Normalize + anchor_embedding = tf.nn.l2_normalize(anchor_embedding, 1, 1e-10, name="anchor") + positive_embedding = tf.nn.l2_normalize(positive_embedding, 1, 1e-10, name="positive") + negative_embedding = tf.nn.l2_normalize(negative_embedding, 1, 1e-10, name="negative") - with tf.name_scope("triplet_loss"): - # Normalize - anchor_embedding = tf.nn.l2_normalize(anchor_embedding, 1, 1e-10, name="anchor") - positive_embedding = tf.nn.l2_normalize(positive_embedding, 1, 1e-10, name="positive") - negative_embedding = tf.nn.l2_normalize(negative_embedding, 1, 1e-10, name="negative") + anchor_mean = tf.reduce_mean(anchor_embedding, 0) - d_positive = tf.reduce_sum(tf.square(tf.subtract(anchor_embedding, positive_embedding)), 1) - d_negative = tf.reduce_sum(tf.square(tf.subtract(anchor_embedding, negative_embedding)), 1) + d_positive = tf.reduce_sum(tf.square(tf.subtract(anchor_mean, positive_embedding)), 1) + d_negative = tf.reduce_sum(tf.square(tf.subtract(anchor_mean, negative_embedding)), 1) - basic_loss = tf.add(tf.subtract(d_positive, d_negative), self.margin) - loss = tf.reduce_mean(tf.maximum(basic_loss, 0.0), 0) + basic_loss = tf.add(tf.subtract(d_positive, d_negative), margin) + loss = tf.reduce_mean(tf.maximum(basic_loss, 0.0), 0, name=tf.GraphKeys.LOSSES) - loss_dict = dict() - loss_dict['loss'] = loss - loss_dict['between_class'] = tf.reduce_mean(d_negative) - loss_dict['within_class'] = tf.reduce_mean(d_positive) + return loss, tf.reduce_mean(d_negative), tf.reduce_mean(d_positive) - return loss_dict + diff --git a/bob/learn/tensorflow/loss/__init__.py b/bob/learn/tensorflow/loss/__init__.py index 19e58a34..1cf46711 100755 --- a/bob/learn/tensorflow/loss/__init__.py +++ b/bob/learn/tensorflow/loss/__init__.py @@ -1,9 +1,7 @@ -from .BaseLoss import BaseLoss, MeanSoftMaxLoss, MeanSoftMaxLossCenterLoss -from .ContrastiveLoss import ContrastiveLoss -from .TripletLoss import TripletLoss -from .TripletAverageLoss import TripletAverageLoss -from .TripletFisherLoss import TripletFisherLoss -from .NegLogLoss import NegLogLoss +from .BaseLoss import mean_cross_entropy_loss, mean_cross_entropy_center_loss +from .ContrastiveLoss import contrastive_loss +from .TripletLoss import triplet_loss, triplet_average_loss, triplet_fisher_loss +#from .NegLogLoss import NegLogLoss # gets sphinx autodoc done right - don't remove it @@ -21,13 +19,9 @@ def __appropriate__(*args): for obj in args: obj.__module__ = __name__ __appropriate__( - BaseLoss, - ContrastiveLoss, - TripletLoss, - TripletFisherLoss, - TripletAverageLoss, - NegLogLoss, - MeanSoftMaxLoss + mean_cross_entropy_loss, mean_cross_entropy_center_loss, + contrastive_loss, + triplet_loss, triplet_average_loss, triplet_fisher_loss ) __all__ = [_ for _ in dir() if not _.startswith('_')] diff --git a/bob/learn/tensorflow/network/Chopra.py b/bob/learn/tensorflow/network/Chopra.py index e8933ad0..2a4328f9 100755 --- a/bob/learn/tensorflow/network/Chopra.py +++ b/bob/learn/tensorflow/network/Chopra.py @@ -1,13 +1,22 @@ #!/usr/bin/env python # vim: set fileencoding=utf-8 : # @author: Tiago de Freitas 
Pereira <tiago.pereira@idiap.ch> -# @date: Wed 11 May 2016 09:39:36 CEST import tensorflow as tf -from .utils import append_logits +def chopra(inputs, conv1_kernel_size=[7, 7], + conv1_output=15, -class Chopra(object): + pooling1_size=[2, 2], + + + conv2_kernel_size=[6, 6], + conv2_output=45, + + pooling2_size=[4, 3], + fc1_output=250, + seed=10, + reuse=False,): """Class that creates the architecture presented in the paper: Chopra, Sumit, Raia Hadsell, and Yann LeCun. "Learning a similarity metric discriminatively, with application to @@ -49,79 +58,41 @@ class Chopra(object): fc1_output: - n_classes: If None, no Fully COnnected layer with class output will be created - seed: """ - def __init__(self, - conv1_kernel_size=[7, 7], - conv1_output=15, - - pooling1_size=[2, 2], - - - conv2_kernel_size=[6, 6], - conv2_output=45, - - pooling2_size=[4, 3], - - fc1_output=250, - n_classes=None, - seed=10): - - self.conv1_kernel_size = conv1_kernel_size - self.conv1_output = conv1_output - self.pooling1_size = pooling1_size - - self.conv2_output = conv2_output - self.conv2_kernel_size = conv2_kernel_size - self.pooling2_size = pooling2_size - - self.fc1_output = fc1_output - - self.seed = seed - self.n_classes = n_classes - - - def __call__(self, inputs, reuse=False, end_point='logits'): - slim = tf.contrib.slim - - end_points = dict() - - initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=self.seed) - - graph = slim.conv2d(inputs, self.conv1_output, self.conv1_kernel_size, activation_fn=tf.nn.relu, - stride=1, - weights_initializer=initializer, - scope='conv1', - reuse=reuse) - end_points['conv1'] = graph - - graph = slim.max_pool2d(graph, self.pooling1_size, scope='pool1') - end_points['pool1'] = graph - - graph = slim.conv2d(graph, self.conv2_output, self.conv2_kernel_size, activation_fn=tf.nn.relu, - stride=1, - weights_initializer=initializer, - scope='conv2', reuse=reuse) - end_points['conv2'] = graph - graph = slim.max_pool2d(graph, self.pooling2_size, scope='pool2') - end_points['pool2'] = graph - - graph = slim.flatten(graph, scope='flatten1') - end_points['flatten1'] = graph - - graph = slim.fully_connected(graph, self.fc1_output, - weights_initializer=initializer, - activation_fn=None, - scope='fc1', - reuse=reuse) - end_points['fc1'] = graph - - if self.n_classes is not None: - # Appending the logits layer - graph = append_logits(graph, self.n_classes, reuse) - end_points['logits'] = graph - - return end_points[end_point] + slim = tf.contrib.slim + + end_points = dict() + + initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=seed) + + graph = slim.conv2d(inputs, conv1_output, conv1_kernel_size, activation_fn=tf.nn.relu, + stride=1, + weights_initializer=initializer, + scope='conv1', + reuse=reuse) + end_points['conv1'] = graph + + graph = slim.max_pool2d(graph, pooling1_size, scope='pool1') + end_points['pool1'] = graph + + graph = slim.conv2d(graph, conv2_output, conv2_kernel_size, activation_fn=tf.nn.relu, + stride=1, + weights_initializer=initializer, + scope='conv2', reuse=reuse) + end_points['conv2'] = graph + graph = slim.max_pool2d(graph, pooling2_size, scope='pool2') + end_points['pool2'] = graph + + graph = slim.flatten(graph, scope='flatten1') + end_points['flatten1'] = graph + + graph = slim.fully_connected(graph, fc1_output, + weights_initializer=initializer, + activation_fn=None, + scope='fc1', + reuse=reuse) + end_points['fc1'] = graph + + return graph, end_points diff --git 
a/bob/learn/tensorflow/network/__init__.py b/bob/learn/tensorflow/network/__init__.py index f0997937..68ed993e 100755 --- a/bob/learn/tensorflow/network/__init__.py +++ b/bob/learn/tensorflow/network/__init__.py @@ -1,4 +1,4 @@ -from .Chopra import Chopra +from .Chopra import chopra from .LightCNN9 import LightCNN9 from .LightCNN29 import LightCNN29 from .Dummy import Dummy diff --git a/bob/learn/tensorflow/network/utils.py b/bob/learn/tensorflow/network/utils.py index 8ce0ed8b..ceb35a54 100755 --- a/bob/learn/tensorflow/network/utils.py +++ b/bob/learn/tensorflow/network/utils.py @@ -6,12 +6,9 @@ import tensorflow as tf slim = tf.contrib.slim -def append_logits(graph, n_classes, reuse): - graph = slim.fully_connected(graph, n_classes, activation_fn=None, - weights_initializer=tf.truncated_normal_initializer(stddev=0.1), - weights_regularizer=slim.l2_regularizer(0.1), +def append_logits(graph, n_classes, reuse=False, l2_regularizer=0.001, weights_std=0.1): + return slim.fully_connected(graph, n_classes, activation_fn=None, + weights_initializer=tf.truncated_normal_initializer(stddev=weights_std), + weights_regularizer=slim.l2_regularizer(l2_regularizer), scope='Logits', reuse=reuse) - return graph - - diff --git a/bob/learn/tensorflow/script/lfw_db_to_tfrecords.py b/bob/learn/tensorflow/script/lfw_db_to_tfrecords.py index 9447bf62..7999b635 100755 --- a/bob/learn/tensorflow/script/lfw_db_to_tfrecords.py +++ b/bob/learn/tensorflow/script/lfw_db_to_tfrecords.py @@ -84,18 +84,14 @@ def main(argv=None): create_directories_safe(os.path.dirname(output_file)) - import ipdb; ipdb.set_trace() - n_files = len(enroll) with tf.python_io.TFRecordWriter(output_file) as writer: for e, p, i in zip(enroll, probe, range(len(enroll)) ): logger.info('Processing pair %d out of %d', i + 1, n_files) - - e_path = e.make_path(data_path, extension) - p_path = p.make_path(data_path, extension) - if os.path.exists(p_path) and os.path.exists(e_path): - for path in [e_path, p_path]: + if os.path.exists(e.make_path(data_path, extension)) and os.path.exists(p.make_path(data_path, extension)): + for f in [e, p]: + path = f.make_path(data_path, extension) data = bob.io.image.to_matplotlib(bob.io.base.load(path)).astype(data_type) data = data.tostring() diff --git a/bob/learn/tensorflow/script/train.py b/bob/learn/tensorflow/script/train.py index d177aca7..0c20262e 100755 --- a/bob/learn/tensorflow/script/train.py +++ b/bob/learn/tensorflow/script/train.py @@ -73,7 +73,7 @@ def main(): return True config = imp.load_source('config', args['<configuration>']) - + # Cleaning all variables in case you are loading the checkpoint tf.reset_default_graph() if os.path.exists(output_dir) else None @@ -107,9 +107,9 @@ def main(): train_graph = None validation_graph = None validate_with_embeddings = False - - if hasattr(config, 'train_graph'): - train_graph = config.train_graph + + if hasattr(config, 'logits'): + train_graph = config.logits if hasattr(config, 'validation_graph'): validation_graph = config.validation_graph diff --git a/bob/learn/tensorflow/test/test_cnn.py b/bob/learn/tensorflow/test/test_cnn.py index e2b15776..86dfdfd0 100755 --- a/bob/learn/tensorflow/test/test_cnn.py +++ b/bob/learn/tensorflow/test/test_cnn.py @@ -5,11 +5,13 @@ import numpy from bob.learn.tensorflow.datashuffler import Memory, SiameseMemory, TripletMemory, ImageAugmentation, ScaleFactor, Linear -from bob.learn.tensorflow.network import Chopra -from bob.learn.tensorflow.loss import MeanSoftMaxLoss, ContrastiveLoss, TripletLoss +from 
bob.learn.tensorflow.network import chopra +from bob.learn.tensorflow.loss import mean_cross_entropy_loss, contrastive_loss, triplet_loss from bob.learn.tensorflow.trainers import Trainer, SiameseTrainer, TripletTrainer, constant -from .test_cnn_scratch import validate_network +from bob.learn.tensorflow.test.test_cnn_scratch import validate_network from bob.learn.tensorflow.network import Embedding, LightCNN9 +from bob.learn.tensorflow.network.utils import append_logits + from bob.learn.tensorflow.utils import load_mnist import tensorflow as tf @@ -92,15 +94,15 @@ def test_cnn_trainer(): directory = "./temp/cnn" + # Preparing the graph + inputs = train_data_shuffler("data", from_queue=True) + labels = train_data_shuffler("label", from_queue=True) + logits = append_logits(chopra(inputs, seed=seed)[0], n_classes=10) + # Loss for the softmax - loss = MeanSoftMaxLoss() - - # Preparing the architecture - architecture = Chopra(seed=seed, n_classes=10) - input_pl = train_data_shuffler("data", from_queue=True) + loss = mean_cross_entropy_loss(logits, labels) - graph = architecture(input_pl) - embedding = Embedding(train_data_shuffler("data", from_queue=False), graph) + embedding = Embedding(train_data_shuffler("data", from_queue=False), logits) # One graph trainer trainer = Trainer(train_data_shuffler, @@ -108,7 +110,7 @@ def test_cnn_trainer(): analizer=None, temp_dir=directory ) - trainer.create_network_from_scratch(graph=graph, + trainer.create_network_from_scratch(graph=logits, loss=loss, learning_rate=constant(0.01, name="regular_lr"), optimizer=tf.train.GradientDescentOptimizer(0.01), @@ -122,7 +124,7 @@ def test_cnn_trainer(): assert accuracy > 20. shutil.rmtree(directory) del trainer - del graph + del logits tf.reset_default_graph() assert len(tf.global_variables())==0 @@ -139,7 +141,6 @@ def test_lightcnn_trainer(): validation_data = numpy.vstack((validation_data, numpy.random.normal(2, 0.2, size=(100, 128, 128, 1)))) validation_labels = numpy.hstack((numpy.zeros(100), numpy.ones(100))).astype("uint64") - # Creating datashufflers data_augmentation = ImageAugmentation() train_data_shuffler = Memory(train_data, train_labels, @@ -150,15 +151,17 @@ def test_lightcnn_trainer(): directory = "./temp/cnn" - # Loss for the softmax - loss = MeanSoftMaxLoss() - # Preparing the architecture architecture = LightCNN9(seed=seed, n_classes=2) - input_pl = train_data_shuffler("data", from_queue=True) - graph = architecture(input_pl, end_point="logits") - embedding = Embedding(train_data_shuffler("data", from_queue=False), graph) + inputs = train_data_shuffler("data", from_queue=True) + labels = train_data_shuffler("label", from_queue=True) + logits = architecture(inputs, end_point="logits") + embedding = Embedding(train_data_shuffler("data", from_queue=False), logits) + + # Loss for the softmax + loss = mean_cross_entropy_loss(logits, labels) + # One graph trainer trainer = Trainer(train_data_shuffler, @@ -166,7 +169,7 @@ def test_lightcnn_trainer(): analizer=None, temp_dir=directory ) - trainer.create_network_from_scratch(graph=graph, + trainer.create_network_from_scratch(graph=logits, loss=loss, learning_rate=constant(0.001, name="regular_lr"), optimizer=tf.train.GradientDescentOptimizer(0.001), @@ -179,7 +182,7 @@ def test_lightcnn_trainer(): assert True shutil.rmtree(directory) del trainer - del graph + del logits tf.reset_default_graph() assert len(tf.global_variables())==0 @@ -202,16 +205,15 @@ def test_siamesecnn_trainer(): normalizer=ScaleFactor()) directory = "./temp/siamesecnn" - # Preparing the 
architecture - architecture = Chopra(seed=seed) + # Building the graph + inputs = train_data_shuffler("data") + labels = train_data_shuffler("label") + graph = dict() + graph['left'] = chopra(inputs['left'])[0] + graph['right'] = chopra(inputs['right'], reuse=True)[0] # Loss for the Siamese - loss = ContrastiveLoss(contrastive_margin=4.) - - input_pl = train_data_shuffler("data") - graph = dict() - graph['left'] = architecture(input_pl['left'], end_point="fc1") - graph['right'] = architecture(input_pl['right'], reuse=True, end_point="fc1") + loss = contrastive_loss(graph['left'], graph['right'], labels, contrastive_margin=4.) trainer = SiameseTrainer(train_data_shuffler, iterations=iterations, @@ -229,7 +231,6 @@ def test_siamesecnn_trainer(): assert eer < 0.15 shutil.rmtree(directory) - del architecture del trainer # Just to clean tf.variables tf.reset_default_graph() assert len(tf.global_variables())==0 @@ -254,17 +255,14 @@ def test_tripletcnn_trainer(): directory = "./temp/tripletcnn" - # Preparing the architecture - architecture = Chopra(seed=seed, fc1_output=10) - - # Loss for the Siamese - loss = TripletLoss(margin=4.) - - input_pl = train_data_shuffler("data") + inputs = train_data_shuffler("data") + labels = train_data_shuffler("label") graph = dict() - graph['anchor'] = architecture(input_pl['anchor'], end_point="fc1") - graph['positive'] = architecture(input_pl['positive'], reuse=True, end_point="fc1") - graph['negative'] = architecture(input_pl['negative'], reuse=True, end_point="fc1") + graph['anchor'] = chopra(inputs['anchor'])[0] + graph['positive'] = chopra(inputs['positive'], reuse=True)[0] + graph['negative'] = chopra(inputs['negative'], reuse=True)[0] + + loss = triplet_loss(graph['anchor'], graph['positive'], graph['negative']) # One graph trainer trainer = TripletTrainer(train_data_shuffler, @@ -283,7 +281,6 @@ def test_tripletcnn_trainer(): assert eer < 0.15 shutil.rmtree(directory) - del architecture del trainer # Just to clean tf.variables tf.reset_default_graph() assert len(tf.global_variables())==0 diff --git a/bob/learn/tensorflow/test/test_cnn_other_losses.py b/bob/learn/tensorflow/test/test_cnn_other_losses.py index bc5f3ae7..dfcbc34e 100755 --- a/bob/learn/tensorflow/test/test_cnn_other_losses.py +++ b/bob/learn/tensorflow/test/test_cnn_other_losses.py @@ -5,9 +5,10 @@ import numpy from bob.learn.tensorflow.datashuffler import TFRecord -from bob.learn.tensorflow.loss import MeanSoftMaxLossCenterLoss, MeanSoftMaxLoss +from bob.learn.tensorflow.loss import mean_cross_entropy_loss, mean_cross_entropy_center_loss from bob.learn.tensorflow.trainers import Trainer, constant from bob.learn.tensorflow.utils import load_mnist +from bob.learn.tensorflow.network.utils import append_logits import tensorflow as tf import shutil import os @@ -25,7 +26,7 @@ directory = "./temp/cnn_scratch" slim = tf.contrib.slim -def scratch_network_embeding_example(train_data_shuffler, reuse=False, get_embedding=False): +def scratch_network_embeding_example(train_data_shuffler, reuse=False): if isinstance(train_data_shuffler, tf.Tensor): inputs = train_data_shuffler @@ -41,19 +42,7 @@ def scratch_network_embeding_example(train_data_shuffler, reuse=False, get_embed prelogits = slim.fully_connected(graph, 30, activation_fn=None, scope='fc1', weights_initializer=initializer, reuse=reuse) - if get_embedding: - embedding = tf.nn.l2_normalize(prelogits, dim=1, name="embedding") - return embedding, None - else: - logits = slim.fully_connected(prelogits, 10, activation_fn=None, scope='logits', - 
weights_initializer=initializer, reuse=reuse) - - #logits_prelogits = dict() - #logits_prelogits['logits'] = logits - #logits_prelogits['prelogits'] = prelogits - - return logits, prelogits - + return prelogits def test_center_loss_tfrecord_embedding_validation(): tf.reset_default_graph() @@ -95,6 +84,7 @@ def test_center_loss_tfrecord_embedding_validation(): create_tf_record(tfrecords_filename_val, validation_data, validation_labels) filename_queue_val = tf.train.string_input_producer([tfrecords_filename_val], num_epochs=55, name="input_validation") + # Creating the CNN using the TFRecord as input train_data_shuffler = TFRecord(filename_queue=filename_queue, batch_size=batch_size) @@ -102,12 +92,15 @@ def test_center_loss_tfrecord_embedding_validation(): validation_data_shuffler = TFRecord(filename_queue=filename_queue_val, batch_size=2000) - graph, prelogits = scratch_network_embeding_example(train_data_shuffler) - validation_graph,_ = scratch_network_embeding_example(validation_data_shuffler, reuse=True, get_embedding=True) + prelogits = scratch_network_embeding_example(train_data_shuffler) + logits = append_logits(prelogits, n_classes=10) + validation_graph = tf.nn.l2_normalize(scratch_network_embeding_example(validation_data_shuffler, reuse=True), 1) + + labels = train_data_shuffler("label", from_queue=False) # Setting the placeholders # Loss for the softmax - loss = MeanSoftMaxLossCenterLoss(n_classes=10, factor=0.1) + loss = mean_cross_entropy_center_loss(logits, prelogits, labels, n_classes=10, factor=0.1) # One graph trainer trainer = Trainer(train_data_shuffler, @@ -119,14 +112,13 @@ def test_center_loss_tfrecord_embedding_validation(): learning_rate = constant(0.01, name="regular_lr") - trainer.create_network_from_scratch(graph=graph, + trainer.create_network_from_scratch(graph=logits, validation_graph=validation_graph, loss=loss, learning_rate=learning_rate, optimizer=tf.train.GradientDescentOptimizer(learning_rate), prelogits=prelogits ) - trainer.train() assert True @@ -155,26 +147,8 @@ def test_center_loss_tfrecord_embedding_validation(): temp_dir=directory) trainer.create_network_from_file(directory) - - import ipdb; ipdb.set_trace(); - trainer.train() - """ - - # Inference. 
TODO: Wrap this in a package - file_name = os.path.join(directory, "model.ckp.meta") - images = tf.placeholder(tf.float32, shape=(None, 28, 28, 1)) - graph ,_ = scratch_network_embeding_example(images, reuse=False) - - session = tf.Session() - session.run(tf.global_variables_initializer()) - saver = tf.train.import_meta_graph(file_name, clear_devices=True) - saver.restore(session, tf.train.latest_checkpoint(os.path.dirname("./temp/cnn_scratch/"))) - data = numpy.random.rand(2, 28, 28, 1).astype("float32") - assert session.run(graph, feed_dict={images: data}).shape == (2, 10) - """ - os.remove(tfrecords_filename) os.remove(tfrecords_filename_val) diff --git a/bob/learn/tensorflow/test/test_cnn_scratch.py b/bob/learn/tensorflow/test/test_cnn_scratch.py index 836689f7..04b3d6f3 100755 --- a/bob/learn/tensorflow/test/test_cnn_scratch.py +++ b/bob/learn/tensorflow/test/test_cnn_scratch.py @@ -6,7 +6,7 @@ import numpy from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation, ScaleFactor, Linear, TFRecord from bob.learn.tensorflow.network import Embedding -from bob.learn.tensorflow.loss import BaseLoss, MeanSoftMaxLoss +from bob.learn.tensorflow.loss import mean_cross_entropy_loss, contrastive_loss, triplet_loss from bob.learn.tensorflow.trainers import Trainer, constant from bob.learn.tensorflow.utils import load_mnist import tensorflow as tf diff --git a/bob/learn/tensorflow/trainers/SiameseTrainer.py b/bob/learn/tensorflow/trainers/SiameseTrainer.py index cde9bb5a..9beba407 100755 --- a/bob/learn/tensorflow/trainers/SiameseTrainer.py +++ b/bob/learn/tensorflow/trainers/SiameseTrainer.py @@ -97,8 +97,6 @@ class SiameseTrainer(Trainer): self.validation_graph = None self.loss = None - - self.predictor = None self.validation_predictor = None self.optimizer_class = None @@ -139,9 +137,6 @@ class SiameseTrainer(Trainer): raise ValueError("`graph` should be a dictionary with two elements (`left`and `right`)") self.loss = loss - self.predictor = self.loss(self.label_ph, - self.graph["left"], - self.graph["right"]) self.optimizer_class = optimizer self.learning_rate = learning_rate @@ -156,9 +151,9 @@ class SiameseTrainer(Trainer): tf.add_to_collection("graph_right", self.graph['right']) # Saving pointers to the loss - tf.add_to_collection("predictor_loss", self.predictor['loss']) - tf.add_to_collection("predictor_between_class_loss", self.predictor['between_class']) - tf.add_to_collection("predictor_within_class_loss", self.predictor['within_class']) + tf.add_to_collection("loss", self.loss['loss']) + tf.add_to_collection("between_class_loss", self.loss['between_class']) + tf.add_to_collection("within_class_loss", self.loss['within_class']) # Saving the pointers to the placeholders tf.add_to_collection("data_ph_left", self.data_ph['left']) @@ -167,7 +162,7 @@ class SiameseTrainer(Trainer): # Preparing the optimizer self.optimizer_class._learning_rate = self.learning_rate - self.optimizer = self.optimizer_class.minimize(self.predictor['loss'], global_step=self.global_step) + self.optimizer = self.optimizer_class.minimize(self.loss['loss'], global_step=self.global_step) tf.add_to_collection("optimizer", self.optimizer) tf.add_to_collection("learning_rate", self.learning_rate) @@ -196,10 +191,10 @@ class SiameseTrainer(Trainer): self.label_ph = tf.get_collection("label_ph")[0] # Loading loss from the pointers - self.predictor = dict() - self.predictor['loss'] = tf.get_collection("predictor_loss")[0] - self.predictor['between_class'] = tf.get_collection("predictor_between_class_loss")[0] - 
self.predictor['within_class'] = tf.get_collection("predictor_within_class_loss")[0] + self.loss = dict() + self.loss['loss'] = tf.get_collection("loss")[0] + self.loss['between_class'] = tf.get_collection("between_class_loss")[0] + self.loss['within_class'] = tf.get_collection("within_class_loss")[0] # Loading other elements self.optimizer = tf.get_collection("optimizer")[0] @@ -223,8 +218,8 @@ class SiameseTrainer(Trainer): _, l, bt_class, wt_class, lr, summary = self.session.run([ self.optimizer, - self.predictor['loss'], self.predictor['between_class'], - self.predictor['within_class'], + self.loss['loss'], self.loss['between_class'], + self.loss['within_class'], self.learning_rate, self.summaries_train], feed_dict=feed_dict) logger.info("Loss training set step={0} = {1}".format(step, l)) @@ -238,9 +233,9 @@ class SiameseTrainer(Trainer): tf.summary.histogram(var.op.name, var) # Train summary - tf.summary.scalar('loss', self.predictor['loss']) - tf.summary.scalar('between_class_loss', self.predictor['between_class']) - tf.summary.scalar('within_class_loss', self.predictor['within_class']) + tf.summary.scalar('loss', self.loss['loss']) + tf.summary.scalar('between_class_loss', self.loss['between_class']) + tf.summary.scalar('within_class_loss', self.loss['within_class']) tf.summary.scalar('lr', self.learning_rate) return tf.summary.merge_all() @@ -257,7 +252,7 @@ class SiameseTrainer(Trainer): # Opening a new session for validation feed_dict = self.get_feed_dict(data_shuffler) - l, summary = self.session.run([self.predictor, self.summaries_validation], feed_dict=feed_dict) + l, summary = self.session.run([self.loss, self.summaries_validation], feed_dict=feed_dict) self.validation_summary_writter.add_summary(summary, step) #summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))] diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py index 03ac1670..25c660e8 100755 --- a/bob/learn/tensorflow/trainers/Trainer.py +++ b/bob/learn/tensorflow/trainers/Trainer.py @@ -103,7 +103,6 @@ class Trainer(object): self.loss = None - self.predictor = None self.validation_predictor = None self.validate_with_embeddings = validate_with_embeddings @@ -242,35 +241,32 @@ class Trainer(object): # TODO: SPECIFIC HACK FOR THE CENTER LOSS. 
I NEED TO FIND A CLEAN SOLUTION FOR THAT self.centers = None if prelogits is not None: - self.predictor, self.centers = self.loss(self.graph, prelogits, self.label_ph) + self.loss = loss['loss'] + self.centers = loss['centers'] tf.add_to_collection("centers", self.centers) + tf.add_to_collection("loss", self.loss) tf.add_to_collection("prelogits", prelogits) self.prelogits = prelogits - else: - self.predictor = self.loss(self.graph, self.label_ph) - + self.optimizer_class = optimizer self.learning_rate = learning_rate self.global_step = tf.contrib.framework.get_or_create_global_step() # Preparing the optimizer self.optimizer_class._learning_rate = self.learning_rate - self.optimizer = self.optimizer_class.minimize(self.predictor, global_step=self.global_step) + self.optimizer = self.optimizer_class.minimize(self.loss, global_step=self.global_step) # Saving all the variables self.saver = tf.train.Saver(var_list=tf.global_variables() + tf.local_variables(), keep_checkpoint_every_n_hours=self.keep_checkpoint_every_n_hours) - self.summaries_train = self.create_general_summary(self.predictor, self.graph, self.label_ph) + self.summaries_train = self.create_general_summary(self.loss, self.graph, self.label_ph) # SAving some variables tf.add_to_collection("global_step", self.global_step) - + tf.add_to_collection("loss", self.loss) tf.add_to_collection("graph", self.graph) - - tf.add_to_collection("predictor", self.predictor) - tf.add_to_collection("data_ph", self.data_ph) tf.add_to_collection("label_ph", self.label_ph) @@ -363,7 +359,7 @@ class Trainer(object): self.label_ph = tf.get_collection("label_ph")[0] self.graph = tf.get_collection("graph")[0] - self.predictor = tf.get_collection("predictor")[0] + self.loss = tf.get_collection("loss")[0] # Loding other elements self.optimizer = tf.get_collection("optimizer")[0] @@ -418,15 +414,15 @@ class Trainer(object): if self.train_data_shuffler.prefetch: # TODO: SPECIFIC HACK FOR THE CENTER LOSS. 
I NEED TO FIND A CLEAN SOLUTION FOR THAT if self.centers is None: - _, l, lr, summary = self.session.run([self.optimizer, self.predictor, + _, l, lr, summary = self.session.run([self.optimizer, self.loss, self.learning_rate, self.summaries_train]) else: - _, l, lr, summary, _ = self.session.run([self.optimizer, self.predictor, + _, l, lr, summary, _ = self.session.run([self.optimizer, self.loss, self.learning_rate, self.summaries_train, self.centers]) else: feed_dict = self.get_feed_dict(self.train_data_shuffler) - _, l, lr, summary = self.session.run([self.optimizer, self.predictor, + _, l, lr, summary = self.session.run([self.optimizer, self.loss, self.learning_rate, self.summaries_train], feed_dict=feed_dict) logger.info("Loss training set step={0} = {1}".format(step, l)) diff --git a/bob/learn/tensorflow/trainers/TripletTrainer.py b/bob/learn/tensorflow/trainers/TripletTrainer.py index a8a782ba..d52d3802 100755 --- a/bob/learn/tensorflow/trainers/TripletTrainer.py +++ b/bob/learn/tensorflow/trainers/TripletTrainer.py @@ -99,7 +99,6 @@ class TripletTrainer(Trainer): self.loss = None - self.predictor = None self.validation_predictor = None self.optimizer_class = None @@ -139,9 +138,6 @@ class TripletTrainer(Trainer): raise ValueError("`graph` should be a dictionary with two elements (`anchor`, `positive` and `negative`)") self.loss = loss - self.predictor = self.loss(self.graph["anchor"], - self.graph["positive"], - self.graph["negative"]) self.optimizer_class = optimizer self.learning_rate = learning_rate @@ -158,9 +154,9 @@ class TripletTrainer(Trainer): tf.add_to_collection("graph_negative", self.graph['negative']) # Saving pointers to the loss - tf.add_to_collection("predictor_loss", self.predictor['loss']) - tf.add_to_collection("predictor_between_class_loss", self.predictor['between_class']) - tf.add_to_collection("predictor_within_class_loss", self.predictor['within_class']) + tf.add_to_collection("loss", self.loss['loss']) + tf.add_to_collection("between_class_loss", self.loss['between_class']) + tf.add_to_collection("within_class_loss", self.loss['within_class']) # Saving the pointers to the placeholders tf.add_to_collection("data_ph_anchor", self.data_ph['anchor']) @@ -169,7 +165,7 @@ class TripletTrainer(Trainer): # Preparing the optimizer self.optimizer_class._learning_rate = self.learning_rate - self.optimizer = self.optimizer_class.minimize(self.predictor['loss'], global_step=self.global_step) + self.optimizer = self.optimizer_class.minimize(self.loss['loss'], global_step=self.global_step) tf.add_to_collection("optimizer", self.optimizer) tf.add_to_collection("learning_rate", self.learning_rate) @@ -196,10 +192,10 @@ class TripletTrainer(Trainer): self.data_ph['negative'] = tf.get_collection("data_ph_negative")[0] # Loading loss from the pointers - self.predictor = dict() - self.predictor['loss'] = tf.get_collection("predictor_loss")[0] - self.predictor['between_class'] = tf.get_collection("predictor_between_class_loss")[0] - self.predictor['within_class'] = tf.get_collection("predictor_within_class_loss")[0] + self.loss = dict() + self.loss['loss'] = tf.get_collection("loss")[0] + self.loss['between_class'] = tf.get_collection("between_class_loss")[0] + self.loss['within_class'] = tf.get_collection("within_class_loss")[0] # Loading other elements self.optimizer = tf.get_collection("optimizer")[0] @@ -221,8 +217,8 @@ class TripletTrainer(Trainer): feed_dict = self.get_feed_dict(self.train_data_shuffler) _, l, bt_class, wt_class, lr, summary = self.session.run([ 
self.optimizer, - self.predictor['loss'], self.predictor['between_class'], - self.predictor['within_class'], + self.loss['loss'], self.loss['between_class'], + self.loss['within_class'], self.learning_rate, self.summaries_train], feed_dict=feed_dict) logger.info("Loss training set step={0} = {1}".format(step, l)) @@ -231,9 +227,9 @@ class TripletTrainer(Trainer): def create_general_summary(self): # Train summary - tf.summary.scalar('loss', self.predictor['loss']) - tf.summary.scalar('between_class_loss', self.predictor['between_class']) - tf.summary.scalar('within_class_loss', self.predictor['within_class']) + tf.summary.scalar('loss', self.loss['loss']) + tf.summary.scalar('between_class_loss', self.loss['between_class']) + tf.summary.scalar('within_class_loss', self.loss['within_class']) tf.summary.scalar('lr', self.learning_rate) return tf.summary.merge_all() -- GitLab
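
Editor's note (not part of the patch): the hunks above replace the class-based losses and the Chopra network with plain functions that return tensors or dictionaries of tensors. The following is a minimal usage sketch assembled from the signatures introduced in this patch; it is an illustration only, not code from the commit. The 28x28x1 input shape, the placeholder names, and the n_classes=10 value are assumptions that mirror the MNIST-based tests in test_cnn.py.

    # Sketch of the function-based API introduced by this patch (TensorFlow 1.x).
    # Assumes bob.learn.tensorflow at this commit; shapes/names below are illustrative.
    import tensorflow as tf

    from bob.learn.tensorflow.network import chopra
    from bob.learn.tensorflow.network.utils import append_logits
    from bob.learn.tensorflow.loss import mean_cross_entropy_loss, contrastive_loss

    # Softmax path: chopra() now returns (fc1_output, end_points); the logits layer
    # is appended explicitly and the loss is a plain function of (logits, labels).
    inputs = tf.placeholder(tf.float32, shape=(None, 28, 28, 1), name="data")
    labels = tf.placeholder(tf.int64, shape=(None,), name="label")

    prelogits, end_points = chopra(inputs, seed=10)
    logits = append_logits(prelogits, n_classes=10)
    loss = mean_cross_entropy_loss(logits, labels)  # scalar tensor, regularization losses included

    # Siamese path (separate graph, so the variable scopes do not collide with the
    # ones created above): the loss now returns a dict with 'loss', 'between_class'
    # and 'within_class', and the trainers minimize loss['loss'] directly.
    tf.reset_default_graph()
    left = tf.placeholder(tf.float32, shape=(None, 28, 28, 1), name="left")
    right = tf.placeholder(tf.float32, shape=(None, 28, 28, 1), name="right")
    pair_labels = tf.placeholder(tf.int64, shape=(None,), name="pair_label")

    left_embedding, _ = chopra(left)
    right_embedding, _ = chopra(right, reuse=True)
    siamese_loss = contrastive_loss(left_embedding, right_embedding, pair_labels,
                                    contrastive_margin=4.0)
    total_loss = siamese_loss['loss']  # tensor handed to the optimizer by SiameseTrainer

triplet_loss follows the same convention as contrastive_loss (a dictionary with 'loss', 'between_class' and 'within_class'), which is what SiameseTrainer and TripletTrainer now read in place of the old predictor object.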