diff --git a/bob/learn/tensorflow/estimators/Logits.py b/bob/learn/tensorflow/estimators/Logits.py
index d2b566d9d3e2966aa0aec1141e02347520b14f78..06b0398473b3ae5af37ab3734447d78066279e83 100755
--- a/bob/learn/tensorflow/estimators/Logits.py
+++ b/bob/learn/tensorflow/estimators/Logits.py
@@ -80,6 +80,12 @@ class Logits(estimator.Estimator):
           "scopes": dict({"/": "/"}),
           "trainable_variables": []
         }
+
+    apply_moving_averages: bool
+        Apply an exponential moving average to the trainable variables and to the loss.
+        https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+        By default the decay for the variable averages is 0.9999 and for the loss it is 0.9.
+
     """

     def __init__(self,
@@ -92,7 +98,8 @@ class Logits(estimator.Estimator):
                  model_dir="",
                  validation_batch_size=None,
                  params=None,
-                 extra_checkpoint=None):
+                 extra_checkpoint=None,
+                 apply_moving_averages=True):

         self.architecture = architecture
         self.optimizer = optimizer
@@ -107,7 +114,6 @@ class Logits(estimator.Estimator):
             check_features(features)
             data = features['data']
             key = features['key']
-
             # Configure the Training Op (for TRAIN mode)
             if mode == tf.estimator.ModeKeys.TRAIN:
@@ -121,17 +127,36 @@ class Logits(estimator.Estimator):
                     trainable_variables=trainable_variables)[0]
                 logits = append_logits(prelogits, n_classes)

-                # Compute Loss (for both TRAIN and EVAL modes)
-                self.loss = self.loss_op(logits, labels)
-
                 if self.extra_checkpoint is not None:
                     tf.contrib.framework.init_from_checkpoint(
                         self.extra_checkpoint["checkpoint_path"],
                         self.extra_checkpoint["scopes"])

                 global_step = tf.train.get_or_create_global_step()
-                train_op = self.optimizer.minimize(
-                    self.loss, global_step=global_step)
+
+                # Maintain an exponential moving average of the trainable variables.
+                if apply_moving_averages:
+                    variable_averages = tf.train.ExponentialMovingAverage(0.9999, global_step)
+                    variable_averages_op = variable_averages.apply(tf.trainable_variables())
+                else:
+                    variable_averages_op = tf.no_op(name='noop')
+
+                with tf.control_dependencies([variable_averages_op]):
+
+                    # Compute Loss (for both TRAIN and EVAL modes)
+                    self.loss = self.loss_op(logits, labels)
+
+                    # Compute the moving average of all individual losses and the total loss.
+                    loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
+                    loss_averages_op = loss_averages.apply(tf.get_collection(tf.GraphKeys.LOSSES))
+
+                    for l in tf.get_collection(tf.GraphKeys.LOSSES):
+                        tf.summary.scalar(l.op.name + "_averaged", loss_averages.average(l))
+
+                    train_op = tf.group(
+                        self.optimizer.minimize(self.loss, global_step=global_step),
+                        variable_averages_op, loss_averages_op)
+
                 return tf.estimator.EstimatorSpec(
                     mode=mode, loss=self.loss, train_op=train_op)
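The ordering above matters: the shadow variables are refreshed through the control dependency before the loss is built, and both averaging ops are grouped with the optimizer step so a single train_op drives all of them. A minimal, self-contained sketch of the same pattern (the one-layer placeholder network is illustrative only, not part of this patch):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=(None, 4))
    labels = tf.placeholder(tf.int64, shape=(None,))
    logits = tf.layers.dense(x, 2)  # stand-in for the real architecture

    global_step = tf.train.get_or_create_global_step()

    # Shadow copies of the trainable variables, decayed at 0.9999
    variable_averages = tf.train.ExponentialMovingAverage(0.9999, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())

    with tf.control_dependencies([variable_averages_op]):
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=labels))

        # Moving average of the loss itself, decayed at 0.9
        loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
        loss_averages_op = loss_averages.apply([loss])

        # One op that steps the optimizer and refreshes both averages
        train_op = tf.group(
            tf.train.GradientDescentOptimizer(0.1).minimize(
                loss, global_step=global_step),
            variable_averages_op, loss_averages_op)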
@@ -250,6 +275,13 @@ class LogitsCenterLoss(estimator.Estimator):
           "scopes": dict({"/": "/"}),
           "trainable_variables": []
         }
+
+    apply_moving_averages: bool
+        Apply an exponential moving average to the trainable variables and to the loss.
+        https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+        By default the decay for the variable averages is 0.9999 and for the loss it is 0.9.
+
     """

@@ -266,6 +298,7 @@ class LogitsCenterLoss(estimator.Estimator):
                  validation_batch_size=None,
                  params=None,
                  extra_checkpoint=None,
+                 apply_moving_averages=True
                  ):

         self.architecture = architecture
@@ -307,27 +340,43 @@ class LogitsCenterLoss(estimator.Estimator):
                     trainable_variables=trainable_variables)[0]
                 logits = append_logits(prelogits, n_classes)

-                # Compute Loss (for TRAIN mode)
-                loss_dict = mean_cross_entropy_center_loss(
-                    logits,
-                    prelogits,
-                    labels,
-                    self.n_classes,
-                    alpha=self.alpha,
-                    factor=self.factor)
-
-                self.loss = loss_dict['loss']
-                centers = loss_dict['centers']
-
-                if self.extra_checkpoint is not None:
-                    tf.contrib.framework.init_from_checkpoint(
-                        self.extra_checkpoint["checkpoint_path"],
-                        self.extra_checkpoint["scopes"])
-
                 global_step = tf.train.get_or_create_global_step()
-                train_op = tf.group(
-                    self.optimizer.minimize(
-                        self.loss, global_step=global_step), centers)
+
+                # Maintain an exponential moving average of the trainable variables.
+                if apply_moving_averages:
+                    variable_averages = tf.train.ExponentialMovingAverage(0.9999, global_step)
+                    variable_averages_op = variable_averages.apply(tf.trainable_variables())
+                else:
+                    variable_averages_op = tf.no_op(name='noop')
+
+                with tf.control_dependencies([variable_averages_op]):
+                    # Compute Loss (for TRAIN mode)
+                    loss_dict = mean_cross_entropy_center_loss(
+                        logits,
+                        prelogits,
+                        labels,
+                        self.n_classes,
+                        alpha=self.alpha,
+                        factor=self.factor)
+
+                    self.loss = loss_dict['loss']
+                    centers = loss_dict['centers']
+
+                    # Compute the moving average of all individual losses and the total loss.
+                    loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
+                    loss_averages_op = loss_averages.apply(tf.get_collection(tf.GraphKeys.LOSSES))
+
+                    for l in tf.get_collection(tf.GraphKeys.LOSSES):
+                        tf.summary.scalar(l.op.name + "_averaged", loss_averages.average(l))
+
+                    if self.extra_checkpoint is not None:
+                        tf.contrib.framework.init_from_checkpoint(
+                            self.extra_checkpoint["checkpoint_path"],
+                            self.extra_checkpoint["scopes"])
+
+                    train_op = tf.group(
+                        self.optimizer.minimize(
+                            self.loss, global_step=global_step),
+                        centers, variable_averages_op, loss_averages_op)

                 return tf.estimator.EstimatorSpec(
                     mode=mode, loss=self.loss, train_op=train_op)
diff --git a/bob/learn/tensorflow/loss/BaseLoss.py b/bob/learn/tensorflow/loss/BaseLoss.py
index 4dc7f583fcf57d29b5b25ed55cd48a9c91b04aab..6053b8b543d334c54c91df719bfb51d7a1f47914 100644
--- a/bob/learn/tensorflow/loss/BaseLoss.py
+++ b/bob/learn/tensorflow/loss/BaseLoss.py
@@ -22,19 +22,22 @@ def mean_cross_entropy_loss(logits, labels, add_regularization_losses=True):
     """

     with tf.variable_scope('cross_entropy_loss'):
-
-        loss = tf.reduce_mean(
+        cross_loss = tf.reduce_mean(
             tf.nn.sparse_softmax_cross_entropy_with_logits(
                 logits=logits, labels=labels),
-            name=tf.GraphKeys.LOSSES)
+            name="cross_entropy_loss")
+
+        tf.summary.scalar('cross_entropy_loss', cross_loss)
+        tf.add_to_collection(tf.GraphKeys.LOSSES, cross_loss)

         if add_regularization_losses:
             regularization_losses = tf.get_collection(
                 tf.GraphKeys.REGULARIZATION_LOSSES)
-            return tf.add_n(
-                [loss] + regularization_losses, name=tf.GraphKeys.LOSSES)
+
+            total_loss = tf.add_n([cross_loss] + regularization_losses, name="total_loss")
+            return total_loss
         else:
-            return loss
+            return cross_loss


 def mean_cross_entropy_center_loss(logits,
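With the cross-entropy term now registered via tf.add_to_collection(tf.GraphKeys.LOSSES, ...), the estimators' averaging loop can find it by collection instead of by the overloaded tensor name used before. A short sketch of what a caller observes (assuming the function is importable from bob.learn.tensorflow.loss; adjust the import to the actual package layout):

    import tensorflow as tf
    from bob.learn.tensorflow.loss import mean_cross_entropy_loss

    logits = tf.placeholder(tf.float32, shape=(None, 10))
    labels = tf.placeholder(tf.int64, shape=(None,))

    total_loss = mean_cross_entropy_loss(logits, labels)

    # The raw cross-entropy term is now discoverable by collection:
    for l in tf.get_collection(tf.GraphKeys.LOSSES):
        print(l.op.name)  # e.g. cross_entropy_loss/cross_entropy_loss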
@@ -58,12 +61,12 @@ def mean_cross_entropy_center_loss(logits,
     """
     # Cross entropy
     with tf.variable_scope('cross_entropy_loss'):
-        loss = tf.reduce_mean(
+        cross_loss = tf.reduce_mean(
             tf.nn.sparse_softmax_cross_entropy_with_logits(
                 logits=logits, labels=labels),
-            name=tf.GraphKeys.LOSSES)
-
-        tf.summary.scalar('cross_entropy_loss', loss)
+            name="cross_entropy_loss")
+        tf.add_to_collection(tf.GraphKeys.LOSSES, cross_loss)
+        tf.summary.scalar('cross_entropy_loss', cross_loss)

     # Appending center loss
     with tf.variable_scope('center_loss'):
@@ -89,7 +92,8 @@ def mean_cross_entropy_center_loss(logits,
         regularization_losses = tf.get_collection(
             tf.GraphKeys.REGULARIZATION_LOSSES)
         total_loss = tf.add_n(
-            [loss] + regularization_losses, name=tf.GraphKeys.LOSSES)
+            [cross_loss] + regularization_losses, name="total_loss")
+        tf.add_to_collection(tf.GraphKeys.LOSSES, total_loss)
         tf.summary.scalar('total_loss', total_loss)

     loss = dict()
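mean_cross_entropy_center_loss returns a dict rather than a single tensor, which is why LogitsCenterLoss above has to group the 'centers' update op into its train_op. Hypothetical usage (the alpha and factor values are illustrative, and the import path assumes a package-level re-export):

    import tensorflow as tf
    from bob.learn.tensorflow.loss import mean_cross_entropy_center_loss

    prelogits = tf.placeholder(tf.float32, shape=(None, 128))
    logits = tf.placeholder(tf.float32, shape=(None, 10))
    labels = tf.placeholder(tf.int64, shape=(None,))

    loss_dict = mean_cross_entropy_center_loss(
        logits, prelogits, labels, 10, alpha=0.9, factor=0.01)

    loss = loss_dict['loss']        # total loss to minimize
    centers = loss_dict['centers']  # update op that must run with train_op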
diff --git a/bob/learn/tensorflow/network/InceptionResnetV1.py b/bob/learn/tensorflow/network/InceptionResnetV1.py
index 3ff4259dc1690201cdefbdccac8af39a2557f47f..f65c323a12b262ca229a3d45149007a01c9c9a71 100644
--- a/bob/learn/tensorflow/network/InceptionResnetV1.py
+++ b/bob/learn/tensorflow/network/InceptionResnetV1.py
@@ -24,7 +24,7 @@ from __future__ import print_function

 import tensorflow as tf
 import tensorflow.contrib.slim as slim
-
+from .utils import is_trainable

 # Inception-Resnet-A
 def block35(net,
@@ -254,13 +254,46 @@ def reduction_b(net, reuse=None, trainable_variables=True):
             3)
     return net

+
+def inception_resnet_v1_batch_norm(inputs,
+                                   dropout_keep_prob=0.8,
+                                   bottleneck_layer_size=128,
+                                   reuse=None,
+                                   scope='InceptionResnetV1',
+                                   mode=tf.estimator.ModeKeys.TRAIN,
+                                   trainable_variables=None,
+                                   weight_decay=1e-5,
+                                   **kwargs):
+    """
+    Creates the Inception Resnet V1 model applying batch norm to each
+    convolutional and fully connected layer.
+
+    Parameters
+    ----------
+
+    inputs: a 4-D tensor of size [batch_size, height, width, 3].
+
+    dropout_keep_prob: float, the fraction to keep before final layer.
+
+    reuse: whether or not the network and its variables should be reused. To be
+        able to reuse 'scope' must be given.
+
+    scope: Optional variable_scope.
+
+    mode: whether the graph is built for training (TRAIN) or inference.
+
+    trainable_variables: list
+        List of variables to be trainable=True
+
+    weight_decay: float, the L2 regularization coefficient.
+
+    Returns
+    -------
+    logits: the logits outputs of the model.
+    end_points: the set of end_points from the inception model.
+
+    """
+
-def inference(images,
-              keep_probability,
-              phase_train=True,
-              bottleneck_layer_size=128,
-              weight_decay=0.0,
-              reuse=None):
     batch_norm_params = {
         # Decay for the moving averages.
         'decay': 0.995,
@@ -271,19 +304,20 @@ def inference(images,
         # Moving averages ends up in the trainable variables collection
         'variables_collections': [tf.GraphKeys.TRAINABLE_VARIABLES],
     }
-
+
     with slim.arg_scope(
             [slim.conv2d, slim.fully_connected],
             weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
             weights_regularizer=slim.l2_regularizer(weight_decay),
             normalizer_fn=slim.batch_norm,
             normalizer_params=batch_norm_params):
-        return inception_resnet_v1(
-            images,
-            is_training=phase_train,
-            dropout_keep_prob=keep_probability,
-            bottleneck_layer_size=bottleneck_layer_size,
-            reuse=reuse)
+        return inception_resnet_v1(inputs,
+                                   dropout_keep_prob=dropout_keep_prob,
+                                   bottleneck_layer_size=bottleneck_layer_size,
+                                   reuse=reuse,
+                                   scope=scope,
+                                   mode=mode,
+                                   trainable_variables=trainable_variables)


 def inception_resnet_v1(inputs,
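Usage of the new wrapper is then a one-liner; the arg_scope supplies batch norm and L2 regularization transparently. An illustrative call (the single-channel 160x160 shape mirrors the new tests below; the freezing behaviour depends on the is_trainable helper, sketched further down):

    import tensorflow as tf
    from bob.learn.tensorflow.network import inception_resnet_v1_batch_norm

    inputs = tf.placeholder(tf.float32, shape=(1, 160, 160, 1))

    # Train everything (trainable_variables=None is the default) ...
    net, end_points = inception_resnet_v1_batch_norm(
        inputs, mode=tf.estimator.ModeKeys.TRAIN)

    # ... or, assuming is_trainable checks scope membership, fine-tune
    # only the Bottleneck layer of an already-built graph:
    # net, _ = inception_resnet_v1_batch_norm(
    #     inputs, trainable_variables=["Bottleneck"], reuse=True)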
+ """ end_points = {} @@ -318,127 +362,164 @@ def inception_resnet_v1(inputs, with slim.arg_scope( [slim.batch_norm, slim.dropout], is_training=(mode == tf.estimator.ModeKeys.TRAIN)): + + with slim.arg_scope( [slim.conv2d, slim.max_pool2d, slim.avg_pool2d], stride=1, padding='SAME'): # 149 x 149 x 32 + name = "Conv2d_1a_3x3" + trainable = is_trainable(name, trainable_variables) net = slim.conv2d( inputs, 32, 3, stride=2, padding='VALID', - scope='Conv2d_1a_3x3', - trainable=trainable_variables) - end_points['Conv2d_1a_3x3'] = net + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net + # 147 x 147 x 32 + name = "Conv2d_2a_3x3" + trainable = is_trainable(name, trainable_variables) net = slim.conv2d( net, 32, 3, padding='VALID', - scope='Conv2d_2a_3x3', - trainable=trainable_variables) - end_points['Conv2d_2a_3x3'] = net + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net + # 147 x 147 x 64 + name = "Conv2d_2b_3x3" + trainable = is_trainable(name, trainable_variables) net = slim.conv2d( net, 64, 3, - scope='Conv2d_2b_3x3', - trainable=trainable_variables) - end_points['Conv2d_2b_3x3'] = net + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net # 73 x 73 x 64 net = slim.max_pool2d( net, 3, stride=2, padding='VALID', scope='MaxPool_3a_3x3') end_points['MaxPool_3a_3x3'] = net + # 73 x 73 x 80 + name = "Conv2d_3b_1x1" + trainable = is_trainable(name, trainable_variables) net = slim.conv2d( net, 80, 1, padding='VALID', - scope='Conv2d_3b_1x1', - trainable=trainable_variables) - end_points['Conv2d_3b_1x1'] = net + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net + # 71 x 71 x 192 + name = "Conv2d_4a_3x3" + trainable = is_trainable(name, trainable_variables) net = slim.conv2d( net, 192, 3, padding='VALID', - scope='Conv2d_4a_3x3', - trainable=trainable_variables) - end_points['Conv2d_4a_3x3'] = net + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net + # 35 x 35 x 256 + name = "Conv2d_4b_3x3" + trainable = is_trainable(name, trainable_variables) net = slim.conv2d( net, 256, 3, stride=2, padding='VALID', - scope='Conv2d_4b_3x3', - trainable=trainable_variables) - end_points['Conv2d_4b_3x3'] = net + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net # 5 x Inception-resnet-A + name = "block35" + trainable = is_trainable(name, trainable_variables) net = slim.repeat( net, 5, block35, scale=0.17, - trainable_variables=trainable_variables, + trainable_variables=trainable, reuse=reuse) - end_points['Mixed_5a'] = net + end_points[name] = net # Reduction-A - with tf.variable_scope('Mixed_6a'): + name = "Mixed_6a" + trainable = is_trainable(name, trainable_variables) + with tf.variable_scope(name): net = reduction_a( net, 192, 192, 256, 384, - trainable_variables=trainable_variables, + trainable_variables=trainable, reuse=reuse) - end_points['Mixed_6a'] = net + end_points[name] = net # 10 x Inception-Resnet-B + name = "block17" + trainable = is_trainable(name, trainable_variables) net = slim.repeat( net, 10, block17, scale=0.10, - trainable_variables=trainable_variables, + trainable_variables=trainable, reuse=reuse) - end_points['Mixed_6b'] = net + end_points[name] = net # Reduction-B - with tf.variable_scope('Mixed_7a'): + name = "Mixed_7a" + trainable = is_trainable(name, trainable_variables) + with tf.variable_scope(name): net = reduction_b( net, - trainable_variables=trainable_variables, + trainable_variables=trainable, reuse=reuse) - end_points['Mixed_7a'] = net 
+            end_points[name] = net

             # 5 x Inception-Resnet-C
+            name = "block8"
+            trainable = is_trainable(name, trainable_variables)
             net = slim.repeat(
                 net,
                 5,
                 block8,
                 scale=0.20,
-                trainable_variables=trainable_variables,
+                trainable_variables=trainable,
                 reuse=reuse)
-            end_points['Mixed_8a'] = net
+            end_points[name] = net

+            name = "Mixed_8b"
+            trainable = is_trainable(name, trainable_variables)
             net = block8(
                 net,
                 activation_fn=None,
-                trainable_variables=trainable_variables,
+                trainable_variables=trainable,
                 reuse=reuse)
-            end_points['Mixed_8b'] = net
+            end_points[name] = net

             with tf.variable_scope('Logits'):
                 end_points['PrePool'] = net
@@ -458,12 +539,14 @@ def inception_resnet_v1(inputs,

             end_points['PreLogitsFlatten'] = net

+            name = "Bottleneck"
+            trainable = is_trainable(name, trainable_variables)
             net = slim.fully_connected(
                 net,
                 bottleneck_layer_size,
                 activation_fn=None,
-                scope='Bottleneck',
+                scope=name,
                 reuse=reuse,
-                trainable=trainable_variables)
+                trainable=trainable)

     return net, end_points
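Both network files now route every layer's trainable flag through is_trainable, which is imported from .utils but not shown in this patch. Judging from its call sites (a scope name plus the trainable_variables list, with None meaning everything is trained), it presumably looks like this sketch:

    def is_trainable(name, trainable_variables):
        # No list given: every layer stays trainable (the old behaviour).
        if trainable_variables is None:
            return True
        # Otherwise only the explicitly listed scopes are trained.
        return name in trainable_variables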
diff --git a/bob/learn/tensorflow/network/InceptionResnetV2.py b/bob/learn/tensorflow/network/InceptionResnetV2.py
index aa5ea67f86de8b8e597465b79baf85ef3d06a482..b52fcc48c2a3638fb80b5e86a7c061243f08efcc 100644
--- a/bob/learn/tensorflow/network/InceptionResnetV2.py
+++ b/bob/learn/tensorflow/network/InceptionResnetV2.py
@@ -203,11 +203,45 @@ def block8(net,
     return net


-def inference(images,
-              keep_probability,
-              bottleneck_layer_size=128,
-              weight_decay=0.0,
-              reuse=None):
+def inception_resnet_v2_batch_norm(inputs,
+                                   dropout_keep_prob=0.8,
+                                   bottleneck_layer_size=128,
+                                   reuse=None,
+                                   scope='InceptionResnetV2',
+                                   mode=tf.estimator.ModeKeys.TRAIN,
+                                   trainable_variables=None,
+                                   weight_decay=5e-5,
+                                   **kwargs):
+    """
+    Creates the Inception Resnet V2 model applying batch norm to each
+    convolutional and fully connected layer.
+
+    Parameters
+    ----------
+
+    inputs: a 4-D tensor of size [batch_size, height, width, 3].
+
+    dropout_keep_prob: float, the fraction to keep before final layer.
+
+    reuse: whether or not the network and its variables should be reused. To be
+        able to reuse 'scope' must be given.
+
+    scope: Optional variable_scope.
+
+    mode: whether the graph is built for training (TRAIN) or inference.
+
+    trainable_variables: list
+        List of variables to be trainable=True
+
+    weight_decay: float, the L2 regularization coefficient.
+
+    Returns
+    -------
+    logits: the logits outputs of the model.
+    end_points: the set of end_points from the inception model.
+    """
+
     batch_norm_params = {
         # Decay for the moving averages.
         'decay': 0.995,
@@ -218,6 +252,7 @@ def inference(images,
         # Moving averages ends up in the trainable variables collection
         'variables_collections': [tf.GraphKeys.TRAINABLE_VARIABLES],
     }
+
     with slim.arg_scope(
             [slim.conv2d, slim.fully_connected],
             weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
@@ -225,11 +260,13 @@ def inference(images,
             normalizer_fn=slim.batch_norm,
             normalizer_params=batch_norm_params):
         return inception_resnet_v2(
-            images,
-            mode=tf.estimator.ModeKeys.PREDICT,
-            dropout_keep_prob=keep_probability,
+            inputs,
+            dropout_keep_prob=dropout_keep_prob,
             bottleneck_layer_size=bottleneck_layer_size,
-            reuse=reuse)
+            reuse=reuse,
+            scope=scope,
+            mode=mode,
+            trainable_variables=trainable_variables)


 def inception_resnet_v2(inputs,
@@ -242,7 +279,8 @@ def inception_resnet_v2(inputs,
                         **kwargs):
     """Creates the Inception Resnet V2 model.

-    **Parameters**:
+    Parameters
+    ----------

     inputs: a 4-D tensor of size [batch_size, height, width, 3].
@@ -260,347 +298,360 @@ def inception_resnet_v2(inputs,
     trainable_variables: list
         List of variables to be trainable=True

-    **Returns**:
-
+    Returns
+    -------
     logits: the logits outputs of the model.
     end_points: the set of end_points from the inception model.
     """
     end_points = {}
+
+    batch_norm_params = {
+        # Decay for the moving averages.
+        'decay': 0.995,
+        # epsilon to prevent 0s in variance.
+        'epsilon': 0.001,
+        # force in-place updates of mean and variance estimates
+        'updates_collections': None,
+        # Moving averages ends up in the trainable variables collection
+        'variables_collections': [tf.GraphKeys.TRAINABLE_VARIABLES],
+    }
+
     with tf.variable_scope(scope, 'InceptionResnetV2', [inputs], reuse=reuse):
         with slim.arg_scope(
                 [slim.batch_norm, slim.dropout],
                 is_training=(mode == tf.estimator.ModeKeys.TRAIN)):
-        with slim.arg_scope(
-            [slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
-            stride=1,
-            padding='SAME'):
-            # 149 x 149 x 32
-            name = "Conv2d_1a_3x3"
-            trainable = is_trainable(name, trainable_variables)
-            net = slim.conv2d(
-                inputs,
-                32,
-                3,
-                stride=2,
-                padding='VALID',
-                scope=name,
-                trainable=trainable,
-                reuse=reuse)
-            end_points[name] = net
-
-            # 147 x 147 x 32
-            name = "Conv2d_2a_3x3"
-            trainable = is_trainable(name, trainable_variables)
-            net = slim.conv2d(
-                net,
-                32,
-                3,
-                padding='VALID',
-                scope=name,
-                trainable=trainable,
-                reuse=reuse)
-            end_points[name] = net
-
-            # 147 x 147 x 64
-            name = "Conv2d_2b_3x3"
-            trainable = is_trainable(name, trainable_variables)
-            net = slim.conv2d(
-                net, 64, 3, scope=name, trainable=trainable, reuse=reuse)
-            end_points[name] = net
-
-            # 73 x 73 x 64
-            net = slim.max_pool2d(
-                net, 3, stride=2, padding='VALID', scope='MaxPool_3a_3x3')
-            end_points['MaxPool_3a_3x3'] = net
-
-            # 73 x 73 x 80
-            name = "Conv2d_3b_1x1"
-            trainable = is_trainable(name, trainable_variables)
-            net = slim.conv2d(
-                net,
-                80,
-                1,
-                padding='VALID',
-                scope=name,
-                trainable=trainable,
-                reuse=reuse)
-            end_points[name] = net
-
-            # 71 x 71 x 192
-            name = "Conv2d_4a_3x3"
-            trainable = is_trainable(name, trainable_variables)
-            net = slim.conv2d(
-                net,
-                192,
-                3,
-                padding='VALID',
-                scope=name,
-                trainable=trainable,
-                reuse=reuse)
-            end_points[name] = net
-
-            # 35 x 35 x 192
-            net = slim.max_pool2d(
-                net, 3, stride=2, padding='VALID', scope='MaxPool_5a_3x3')
-            end_points['MaxPool_5a_3x3'] = net
-
-            # 35 x 35 x 320
-            name = "Mixed_5b"
-            trainable = is_trainable(name, trainable_variables)
-            with tf.variable_scope(name):
-                with tf.variable_scope('Branch_0'):
-                    tower_conv = slim.conv2d(
-                        net,
-                        96,
-                        1,
-                        scope='Conv2d_1x1',
-                        trainable=trainable,
-                        reuse=reuse)
-                with tf.variable_scope('Branch_1'):
-                    tower_conv1_0 = slim.conv2d(
-                        net,
-                        48,
-                        1,
-                        scope='Conv2d_0a_1x1',
-                        trainable=trainable,
-                        reuse=reuse)
-                    tower_conv1_1 = slim.conv2d(
-                        tower_conv1_0,
-                        64,
-                        5,
-                        scope='Conv2d_0b_5x5',
-                        trainable=trainable,
-                        reuse=reuse)
-                with tf.variable_scope('Branch_2'):
-                    tower_conv2_0 = slim.conv2d(
-                        net,
-                        64,
-                        1,
-                        scope='Conv2d_0a_1x1',
-                        trainable=trainable,
-                        reuse=reuse)
-                    tower_conv2_1 = slim.conv2d(
-                        tower_conv2_0,
-                        96,
-                        3,
-                        scope='Conv2d_0b_3x3',
-                        trainable=trainable,
-                        reuse=reuse)
-                    tower_conv2_2 = slim.conv2d(
-                        tower_conv2_1,
-                        96,
-                        3,
-                        scope='Conv2d_0c_3x3',
-                        trainable=trainable,
-                        reuse=reuse)
-                with tf.variable_scope('Branch_3'):
-                    tower_pool = slim.avg_pool2d(
-                        net,
-                        3,
-                        stride=1,
-                        padding='SAME',
-                        scope='AvgPool_0a_3x3')
-                    tower_pool_1 = slim.conv2d(
-                        tower_pool,
-                        64,
-                        1,
-                        scope='Conv2d_0b_1x1',
-                        trainable=trainable,
-                        reuse=reuse)
-                net = tf.concat([
-                    tower_conv, tower_conv1_1, tower_conv2_2, tower_pool_1
-                ], 3)
-            end_points[name] = net
-
-            # BLOCK 35
-            name = "Block35"
-            trainable = is_trainable(name, trainable_variables)
-            net = slim.repeat(
-                net,
-                10,
-                block35,
-                scale=0.17,
-                trainable_variables=trainable,
-                reuse=reuse)
-
-            # 17 x 17 x 1024
-            name = "Mixed_6a"
-            trainable = is_trainable(name, trainable_variables)
-            with tf.variable_scope(name):
-                with tf.variable_scope('Branch_0'):
-                    tower_conv = slim.conv2d(
-                        net,
-                        384,
-                        3,
-                        stride=2,
-                        padding='VALID',
-                        scope='Conv2d_1a_3x3',
-                        trainable=trainable,
-                        reuse=reuse)
-                with tf.variable_scope('Branch_1'):
-                    tower_conv1_0 = slim.conv2d(
-                        net,
-                        256,
-                        1,
-                        scope='Conv2d_0a_1x1',
-                        trainable=trainable,
-                        reuse=reuse)
-                    tower_conv1_1 = slim.conv2d(
-                        tower_conv1_0,
-                        256,
-                        3,
-                        scope='Conv2d_0b_3x3',
-                        trainable=trainable,
-                        reuse=reuse)
-                    tower_conv1_2 = slim.conv2d(
-                        tower_conv1_1,
-                        384,
-                        3,
-                        stride=2,
-                        padding='VALID',
-                        scope='Conv2d_1a_3x3',
-                        trainable=trainable,
-                        reuse=reuse)
-                with tf.variable_scope('Branch_2'):
-                    tower_pool = slim.max_pool2d(
-                        net,
-                        3,
-                        stride=2,
-                        padding='VALID',
-                        scope='MaxPool_1a_3x3')
-                net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3)
-
-            end_points[name] = net
-
-            # BLOCK 17
-            name = "Block17"
-            trainable = is_trainable(name, trainable_variables)
-            net = slim.repeat(
-                net,
-                20,
-                block17,
-                scale=0.10,
-                trainable_variables=trainable,
-                reuse=reuse)
-
-            name = "Mixed_7a"
-            trainable = is_trainable(name, trainable_variables)
-            with tf.variable_scope(name):
-                with tf.variable_scope('Branch_0'):
-                    tower_conv = slim.conv2d(
-                        net,
-                        256,
-                        1,
-                        scope='Conv2d_0a_1x1',
-                        trainable=trainable,
-                        reuse=reuse)
-                    tower_conv_1 = slim.conv2d(
-                        tower_conv,
-                        384,
-                        3,
-                        stride=2,
-                        padding='VALID',
-                        scope='Conv2d_1a_3x3',
-                        trainable=trainable,
-                        reuse=reuse)
-                with tf.variable_scope('Branch_1'):
-                    tower_conv1 = slim.conv2d(
-                        net,
-                        256,
-                        1,
-                        scope='Conv2d_0a_1x1',
-                        trainable=trainable,
-                        reuse=reuse)
-                    tower_conv1_1 = slim.conv2d(
-                        tower_conv1,
-                        288,
-                        3,
-                        stride=2,
-                        padding='VALID',
-                        scope='Conv2d_1a_3x3',
-                        trainable=trainable,
-                        reuse=reuse)
-                with tf.variable_scope('Branch_2'):
-                    tower_conv2 = slim.conv2d(
-                        net,
-                        256,
-                        1,
-                        scope='Conv2d_0a_1x1',
-                        trainable=trainable,
-                        reuse=reuse)
-                    tower_conv2_1 = slim.conv2d(
-                        tower_conv2,
-                        288,
-                        3,
-                        scope='Conv2d_0b_3x3',
-                        trainable=trainable,
-                        reuse=reuse)
-                    tower_conv2_2 = slim.conv2d(
-                        tower_conv2_1,
-                        320,
-                        3,
-                        stride=2,
-                        padding='VALID',
-                        scope='Conv2d_1a_3x3',
-                        trainable=trainable,
-                        reuse=reuse)
-                with tf.variable_scope('Branch_3'):
-                    tower_pool = slim.max_pool2d(
+
+            with slim.arg_scope(
+                    [slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
+                    stride=1,
+                    padding='SAME'):
+                # 149 x 149 x 32
+                name = "Conv2d_1a_3x3"
+                trainable = is_trainable(name, trainable_variables)
+                net = slim.conv2d(
+                    inputs,
+                    32,
+                    3,
+                    stride=2,
+                    padding='VALID',
+                    scope=name,
+                    trainable=trainable,
+                    reuse=reuse)
+                end_points[name] = net
+
+                # 147 x 147 x 32
+                name = "Conv2d_2a_3x3"
+                trainable = is_trainable(name, trainable_variables)
+                net = slim.conv2d(
+                    net,
+                    32,
+                    3,
+                    padding='VALID',
+                    scope=name,
+                    trainable=trainable,
+                    reuse=reuse)
+                end_points[name] = net
+
+                # 147 x 147 x 64
+                name = "Conv2d_2b_3x3"
+                trainable = is_trainable(name, trainable_variables)
+                net = slim.conv2d(
+                    net, 64, 3, scope=name, trainable=trainable, reuse=reuse)
+                end_points[name] = net
+
+                # 73 x 73 x 64
+                net = slim.max_pool2d(
+                    net, 3, stride=2, padding='VALID', scope='MaxPool_3a_3x3')
+                end_points['MaxPool_3a_3x3'] = net
+
+                # 73 x 73 x 80
+                name = "Conv2d_3b_1x1"
+                trainable = is_trainable(name, trainable_variables)
+                net = slim.conv2d(
+                    net,
+                    80,
+                    1,
+                    padding='VALID',
+                    scope=name,
+                    trainable=trainable,
+                    reuse=reuse)
+                end_points[name] = net
+
+                # 71 x 71 x 192
+                name = "Conv2d_4a_3x3"
+                trainable = is_trainable(name, trainable_variables)
+                net = slim.conv2d(
+                    net,
+                    192,
+                    3,
+                    padding='VALID',
+                    scope=name,
+                    trainable=trainable,
+                    reuse=reuse)
+                end_points[name] = net
+
+                # 35 x 35 x 192
+                net = slim.max_pool2d(
+                    net, 3, stride=2, padding='VALID', scope='MaxPool_5a_3x3')
+                end_points['MaxPool_5a_3x3'] = net
+
+                # 35 x 35 x 320
+                name = "Mixed_5b"
+                trainable = is_trainable(name, trainable_variables)
+                with tf.variable_scope(name):
+                    with tf.variable_scope('Branch_0'):
+                        tower_conv = slim.conv2d(
+                            net,
+                            96,
+                            1,
+                            scope='Conv2d_1x1',
+                            trainable=trainable,
+                            reuse=reuse)
+                    with tf.variable_scope('Branch_1'):
+                        tower_conv1_0 = slim.conv2d(
+                            net,
+                            48,
+                            1,
+                            scope='Conv2d_0a_1x1',
+                            trainable=trainable,
+                            reuse=reuse)
+                        tower_conv1_1 = slim.conv2d(
+                            tower_conv1_0,
+                            64,
+                            5,
+                            scope='Conv2d_0b_5x5',
+                            trainable=trainable,
+                            reuse=reuse)
+                    with tf.variable_scope('Branch_2'):
+                        tower_conv2_0 = slim.conv2d(
+                            net,
+                            64,
+                            1,
+                            scope='Conv2d_0a_1x1',
+                            trainable=trainable,
+                            reuse=reuse)
+                        tower_conv2_1 = slim.conv2d(
+                            tower_conv2_0,
+                            96,
+                            3,
+                            scope='Conv2d_0b_3x3',
+                            trainable=trainable,
+                            reuse=reuse)
+                        tower_conv2_2 = slim.conv2d(
+                            tower_conv2_1,
+                            96,
+                            3,
+                            scope='Conv2d_0c_3x3',
+                            trainable=trainable,
+                            reuse=reuse)
+                    with tf.variable_scope('Branch_3'):
+                        tower_pool = slim.avg_pool2d(
+                            net,
+                            3,
+                            stride=1,
+                            padding='SAME',
+                            scope='AvgPool_0a_3x3')
+                        tower_pool_1 = slim.conv2d(
+                            tower_pool,
+                            64,
+                            1,
+                            scope='Conv2d_0b_1x1',
+                            trainable=trainable,
+                            reuse=reuse)
+                    net = tf.concat([
+                        tower_conv, tower_conv1_1, tower_conv2_2, tower_pool_1
+                    ], 3)
+                end_points[name] = net
+
+                # BLOCK 35
+                name = "Block35"
+                trainable = is_trainable(name, trainable_variables)
+                net = slim.repeat(
+                    net,
+                    10,
+                    block35,
+                    scale=0.17,
+                    trainable_variables=trainable,
+                    reuse=reuse)
+
+                # 17 x 17 x 1024
+                name = "Mixed_6a"
+                trainable = is_trainable(name, trainable_variables)
+                with tf.variable_scope(name):
+                    with tf.variable_scope('Branch_0'):
+                        tower_conv = slim.conv2d(
+                            net,
+                            384,
+                            3,
+                            stride=2,
+                            padding='VALID',
+                            scope='Conv2d_1a_3x3',
+                            trainable=trainable,
+                            reuse=reuse)
+                    with tf.variable_scope('Branch_1'):
+                        tower_conv1_0 = slim.conv2d(
+                            net,
+                            256,
+                            1,
+                            scope='Conv2d_0a_1x1',
+                            trainable=trainable,
+                            reuse=reuse)
+                        tower_conv1_1 = slim.conv2d(
+                            tower_conv1_0,
+                            256,
+                            3,
+                            scope='Conv2d_0b_3x3',
+                            trainable=trainable,
+                            reuse=reuse)
+                        tower_conv1_2 = slim.conv2d(
+                            tower_conv1_1,
+                            384,
+                            3,
+                            stride=2,
+                            padding='VALID',
+                            scope='Conv2d_1a_3x3',
+                            trainable=trainable,
+                            reuse=reuse)
+                    with tf.variable_scope('Branch_2'):
+                        tower_pool = slim.max_pool2d(
+                            net,
+                            3,
+                            stride=2,
+                            padding='VALID',
+                            scope='MaxPool_1a_3x3')
+                    net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3)
+
+                end_points[name] = net
+
+                # BLOCK 17
+                name = "Block17"
+                trainable = is_trainable(name, trainable_variables)
+                net = slim.repeat(
+                    net,
+                    20,
+                    block17,
+                    scale=0.10,
+                    trainable_variables=trainable,
+                    reuse=reuse)
+
+                name = "Mixed_7a"
+                trainable = is_trainable(name, trainable_variables)
+                with tf.variable_scope(name):
+                    with tf.variable_scope('Branch_0'):
+                        tower_conv = slim.conv2d(
+                            net,
+                            256,
+                            1,
+                            scope='Conv2d_0a_1x1',
+                            trainable=trainable,
+                            reuse=reuse)
+                        tower_conv_1 = slim.conv2d(
+                            tower_conv,
+                            384,
+                            3,
+                            stride=2,
+                            padding='VALID',
+                            scope='Conv2d_1a_3x3',
+                            trainable=trainable,
+                            reuse=reuse)
+                    with tf.variable_scope('Branch_1'):
+                        tower_conv1 = slim.conv2d(
+                            net,
+                            256,
+                            1,
+                            scope='Conv2d_0a_1x1',
+                            trainable=trainable,
+                            reuse=reuse)
+                        tower_conv1_1 = slim.conv2d(
+                            tower_conv1,
+                            288,
+                            3,
+                            stride=2,
+                            padding='VALID',
+                            scope='Conv2d_1a_3x3',
+                            trainable=trainable,
+                            reuse=reuse)
+                    with tf.variable_scope('Branch_2'):
+                        tower_conv2 = slim.conv2d(
+                            net,
+                            256,
+                            1,
+                            scope='Conv2d_0a_1x1',
+                            trainable=trainable,
+                            reuse=reuse)
+                        tower_conv2_1 = slim.conv2d(
+                            tower_conv2,
+                            288,
+                            3,
+                            scope='Conv2d_0b_3x3',
+                            trainable=trainable,
+                            reuse=reuse)
+                        tower_conv2_2 = slim.conv2d(
+                            tower_conv2_1,
+                            320,
+                            3,
+                            stride=2,
+                            padding='VALID',
+                            scope='Conv2d_1a_3x3',
+                            trainable=trainable,
+                            reuse=reuse)
+                    with tf.variable_scope('Branch_3'):
+                        tower_pool = slim.max_pool2d(
+                            net,
+                            3,
+                            stride=2,
+                            padding='VALID',
+                            scope='MaxPool_1a_3x3')
+                    net = tf.concat([
+                        tower_conv_1, tower_conv1_1, tower_conv2_2, tower_pool
+                    ], 3)
+                end_points[name] = net
+
+                # Block 8
+                name = "Block8"
+                trainable = is_trainable(name, trainable_variables)
+                net = slim.repeat(
+                    net,
+                    9,
+                    block8,
+                    scale=0.20,
+                    trainable_variables=trainable,
+                    reuse=reuse)
+                net = block8(
+                    net,
+                    activation_fn=None,
+                    trainable_variables=trainable,
+                    reuse=reuse)
+
+                name = "Conv2d_7b_1x1"
+                trainable = is_trainable(name, trainable_variables)
+                net = slim.conv2d(
+                    net, 1536, 1, scope=name, trainable=trainable, reuse=reuse)
+                end_points[name] = net
+
+                with tf.variable_scope('Logits'):
+                    end_points['PrePool'] = net
+                    # pylint: disable=no-member
+                    net = slim.avg_pool2d(
                         net,
-                        3,
-                        stride=2,
+                        net.get_shape()[1:3],
                         padding='VALID',
-                        scope='MaxPool_1a_3x3')
-                net = tf.concat([
-                    tower_conv_1, tower_conv1_1, tower_conv2_2, tower_pool
-                ], 3)
-            end_points[name] = net
-
-            # Block 8
-            name = "Block8"
-            trainable = is_trainable(name, trainable_variables)
-            net = slim.repeat(
-                net,
-                9,
-                block8,
-                scale=0.20,
-                trainable_variables=trainable,
-                reuse=reuse)
-            net = block8(
-                net,
-                activation_fn=None,
-                trainable_variables=trainable,
-                reuse=reuse)
-
-            name = "Conv2d_7b_1x1"
-            trainable = is_trainable(name, trainable_variables)
-            net = slim.conv2d(
-                net, 1536, 1, scope=name, trainable=trainable, reuse=reuse)
-            end_points[name] = net
-
-            with tf.variable_scope('Logits'):
-                end_points['PrePool'] = net
-                # pylint: disable=no-member
-                net = slim.avg_pool2d(
+                        scope='AvgPool_1a_8x8')
+                    net = slim.flatten(net)
+
+                    net = slim.dropout(net, dropout_keep_prob, scope='Dropout')
+
+                    end_points['PreLogitsFlatten'] = net
+
+                    name = "Bottleneck"
+                    trainable = is_trainable(name, trainable_variables)
+                    net = slim.fully_connected(
                         net,
-                        net.get_shape()[1:3],
-                        padding='VALID',
-                        scope='AvgPool_1a_8x8')
-                net = slim.flatten(net)
-
-                net = slim.dropout(net, dropout_keep_prob, scope='Dropout')
-
-                end_points['PreLogitsFlatten'] = net
-
-                name = "Bottleneck"
-                trainable = is_trainable(name, trainable_variables)
-                net = slim.fully_connected(
-                    net,
-                    bottleneck_layer_size,
-                    activation_fn=None,
-                    scope=name,
-                    reuse=reuse,
-                    trainable=trainable)
-                end_points[name] = net
+                        bottleneck_layer_size,
+                        activation_fn=None,
+                        scope=name,
+                        reuse=reuse,
+                        trainable=trainable)
+                    end_points[name] = net

     return net, end_points
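The V2 wrapper forwards all of its arguments to inception_resnet_v2, so the batch-norm variant accepts exactly the same controls as the plain network. Illustrative usage (the input shape follows the new tests below):

    import tensorflow as tf
    from bob.learn.tensorflow.network import inception_resnet_v2_batch_norm

    inputs = tf.placeholder(tf.float32, shape=(1, 160, 160, 1))
    embeddings, end_points = inception_resnet_v2_batch_norm(
        inputs,
        dropout_keep_prob=0.8,
        bottleneck_layer_size=128,
        mode=tf.estimator.ModeKeys.PREDICT)

    print(end_points['Bottleneck'])  # the 128-d embedding tensor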
diff --git a/bob/learn/tensorflow/network/__init__.py b/bob/learn/tensorflow/network/__init__.py
index 7d4656089bdd850cdea103b33a64ac2266081839..5aca74aa4be5a6ba05bd6e1c3476a413b21935ee 100644
--- a/bob/learn/tensorflow/network/__init__.py
+++ b/bob/learn/tensorflow/network/__init__.py
@@ -3,8 +3,8 @@ from .LightCNN9 import light_cnn9
 from .Dummy import dummy
 from .MLP import mlp
 from .Embedding import Embedding
-from .InceptionResnetV2 import inception_resnet_v2
-from .InceptionResnetV1 import inception_resnet_v1
+from .InceptionResnetV2 import inception_resnet_v2, inception_resnet_v2_batch_norm
+from .InceptionResnetV1 import inception_resnet_v1, inception_resnet_v1_batch_norm
 from . import SimpleCNN
diff --git a/bob/learn/tensorflow/network/utils.py b/bob/learn/tensorflow/network/utils.py
index 5b32ed123aaaccae3ac124b215d926643d9cc8df..852db26727592b00762d1d8984190602f9716fc8 100644
--- a/bob/learn/tensorflow/network/utils.py
+++ b/bob/learn/tensorflow/network/utils.py
@@ -9,7 +9,7 @@ import tensorflow.contrib.slim as slim
 def append_logits(graph,
                   n_classes,
                   reuse=False,
-                  l2_regularizer=0.001,
+                  l2_regularizer=5e-05,
                   weights_std=0.1):
     return slim.fully_connected(
         graph,
diff --git a/bob/learn/tensorflow/test/test_architectures.py b/bob/learn/tensorflow/test/test_architectures.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e6c532709a258397bc7db559f6c53cabf8a076a
--- /dev/null
+++ b/bob/learn/tensorflow/test/test_architectures.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# @author: Tiago de Freitas Pereira
+
+import tensorflow as tf
+from bob.learn.tensorflow.network import inception_resnet_v2, inception_resnet_v2_batch_norm,\
+    inception_resnet_v1, inception_resnet_v1_batch_norm
+
+
+def test_inceptionv2():
+
+    # Testing WITHOUT batch norm
+    inputs = tf.placeholder(tf.float32, shape=(1, 160, 160, 1))
+    graph, _ = inception_resnet_v2(inputs)
+    assert len(tf.trainable_variables()) == 490
+
+    tf.reset_default_graph()
+    assert len(tf.global_variables()) == 0
+
+    # Testing WITH batch norm
+    inputs = tf.placeholder(tf.float32, shape=(1, 160, 160, 1))
+    graph, _ = inception_resnet_v2_batch_norm(inputs)
+    assert len(tf.trainable_variables()) == 900
+
+    tf.reset_default_graph()
+    assert len(tf.global_variables()) == 0
+
+
+def test_inceptionv1():
+
+    # Testing WITHOUT batch norm
+    inputs = tf.placeholder(tf.float32, shape=(1, 160, 160, 1))
+    graph, _ = inception_resnet_v1(inputs)
+    assert len(tf.trainable_variables()) == 266
+
+    tf.reset_default_graph()
+    assert len(tf.global_variables()) == 0
+
+    # Testing WITH batch norm
+    inputs = tf.placeholder(tf.float32, shape=(1, 160, 160, 1))
+    graph, _ = inception_resnet_v1_batch_norm(inputs)
+    assert len(tf.trainable_variables()) == 490
+
+    tf.reset_default_graph()
+    assert len(tf.global_variables()) == 0
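This patch only writes the shadow (averaged) variables during training; nothing here reads them back. In stock TensorFlow 1.x the usual evaluation-time counterpart is ExponentialMovingAverage.variables_to_restore(), sketched here as a hint for eval scripts (not part of this patch):

    import tensorflow as tf

    # Map each variable to its moving-average shadow before restoring,
    # so evaluation runs on the averaged weights:
    variable_averages = tf.train.ExponentialMovingAverage(0.9999)
    saver = tf.train.Saver(variable_averages.variables_to_restore())
    # saver.restore(session, checkpoint_path)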