From 82cf9beb6fa807de1ba889f305fd13a958d4ad46 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira Date: Tue, 19 Dec 2017 10:08:52 +0100 Subject: [PATCH 1/2] Merging inception v1 Updating vgg16 Updated inceptionV2 Added VGG16 Updated VGG16 Updated VGG16 Updated VGG16 Updated inceptionV1 Inceptionv2 batch norm Updated batch norm Added moving average Added batch norm to InceptionV2 and InceptionV1 Applied moving averages Changed the default value for the l2 regularizer Fixed the loss collection Fixed issue with the mode --- bob/learn/tensorflow/estimators/Logits.py | 90 ++- bob/learn/tensorflow/loss/BaseLoss.py | 27 +- .../tensorflow/network/InceptionResnetV1.py | 211 +++-- .../tensorflow/network/InceptionResnetV2.py | 740 ++++++++++-------- bob/learn/tensorflow/network/__init__.py | 5 +- bob/learn/tensorflow/network/utils.py | 2 +- bob/learn/tensorflow/network/vgg.py | 300 +++++++ .../tensorflow/test/test_architectures.py | 45 ++ 8 files changed, 974 insertions(+), 446 deletions(-) create mode 100644 bob/learn/tensorflow/network/vgg.py create mode 100644 bob/learn/tensorflow/test/test_architectures.py diff --git a/bob/learn/tensorflow/estimators/Logits.py b/bob/learn/tensorflow/estimators/Logits.py index d2b566d..28b7832 100755 --- a/bob/learn/tensorflow/estimators/Logits.py +++ b/bob/learn/tensorflow/estimators/Logits.py @@ -92,7 +92,8 @@ class Logits(estimator.Estimator): model_dir="", validation_batch_size=None, params=None, - extra_checkpoint=None): + extra_checkpoint=None, + apply_moving_averages=True): self.architecture = architecture self.optimizer = optimizer @@ -107,7 +108,6 @@ class Logits(estimator.Estimator): check_features(features) data = features['data'] key = features['key'] - # Configure the Training Op (for TRAIN mode) if mode == tf.estimator.ModeKeys.TRAIN: @@ -121,17 +121,36 @@ class Logits(estimator.Estimator): trainable_variables=trainable_variables)[0] logits = append_logits(prelogits, n_classes) - # Compute Loss (for both TRAIN and EVAL modes) - self.loss = self.loss_op(logits, labels) - if self.extra_checkpoint is not None: tf.contrib.framework.init_from_checkpoint( self.extra_checkpoint["checkpoint_path"], self.extra_checkpoint["scopes"]) global_step = tf.train.get_or_create_global_step() - train_op = self.optimizer.minimize( - self.loss, global_step=global_step) + + # Compute the moving average of all individual losses and the total loss. + if apply_moving_averages: + variable_averages = tf.train.ExponentialMovingAverage(0.9999, global_step) + variable_averages_op = variable_averages.apply(tf.trainable_variables()) + else: + variable_averages_op = tf.no_op(name='noop') + + with tf.control_dependencies([variable_averages_op]): + + # Compute Loss (for both TRAIN and EVAL modes) + self.loss = self.loss_op(logits, labels) + + # Compute the moving average of all individual losses and the total loss. + loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg') + loss_averages_op = loss_averages.apply(tf.get_collection(tf.GraphKeys.LOSSES)) + + for l in tf.get_collection(tf.GraphKeys.LOSSES): + tf.summary.scalar(l.op.name+"_averaged", loss_averages.average(l)) + + global_step = tf.train.get_or_create_global_step() + train_op = tf.group(self.optimizer.minimize(self.loss, global_step=global_step), + variable_averages_op, loss_averages_op) + return tf.estimator.EstimatorSpec( mode=mode, loss=self.loss, train_op=train_op) @@ -266,6 +285,7 @@ class LogitsCenterLoss(estimator.Estimator): validation_batch_size=None, params=None, extra_checkpoint=None, + apply_moving_averages=True ): self.architecture = architecture @@ -307,27 +327,43 @@ class LogitsCenterLoss(estimator.Estimator): trainable_variables=trainable_variables)[0] logits = append_logits(prelogits, n_classes) - # Compute Loss (for TRAIN mode) - loss_dict = mean_cross_entropy_center_loss( - logits, - prelogits, - labels, - self.n_classes, - alpha=self.alpha, - factor=self.factor) - - self.loss = loss_dict['loss'] - centers = loss_dict['centers'] - - if self.extra_checkpoint is not None: - tf.contrib.framework.init_from_checkpoint( - self.extra_checkpoint["checkpoint_path"], - self.extra_checkpoint["scopes"]) - global_step = tf.train.get_or_create_global_step() - train_op = tf.group( - self.optimizer.minimize( - self.loss, global_step=global_step), centers) + + # Compute the moving average of all individual losses and the total loss. + if apply_moving_averages: + variable_averages = tf.train.ExponentialMovingAverage(0.9999, global_step) + variable_averages_op = variable_averages.apply(tf.trainable_variables()) + else: + variable_averages_op = tf.no_op(name='noop') + + with tf.control_dependencies([variable_averages_op]): + # Compute Loss (for TRAIN mode) + loss_dict = mean_cross_entropy_center_loss( + logits, + prelogits, + labels, + self.n_classes, + alpha=self.alpha, + factor=self.factor) + + self.loss = loss_dict['loss'] + centers = loss_dict['centers'] + + # Compute the moving average of all individual losses and the total loss. + loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg') + loss_averages_op = loss_averages.apply(tf.get_collection(tf.GraphKeys.LOSSES)) + + for l in tf.get_collection(tf.GraphKeys.LOSSES): + tf.summary.scalar(l.op.name, loss_averages.average(l)) + + if self.extra_checkpoint is not None: + tf.contrib.framework.init_from_checkpoint( + self.extra_checkpoint["checkpoint_path"], + self.extra_checkpoint["scopes"]) + + train_op = tf.group( + self.optimizer.minimize( + self.loss, global_step=global_step), centers, variable_averages_op, loss_averages_op) return tf.estimator.EstimatorSpec( mode=mode, loss=self.loss, train_op=train_op) diff --git a/bob/learn/tensorflow/loss/BaseLoss.py b/bob/learn/tensorflow/loss/BaseLoss.py index 4dc7f58..881c924 100644 --- a/bob/learn/tensorflow/loss/BaseLoss.py +++ b/bob/learn/tensorflow/loss/BaseLoss.py @@ -22,19 +22,23 @@ def mean_cross_entropy_loss(logits, labels, add_regularization_losses=True): """ with tf.variable_scope('cross_entropy_loss'): - - loss = tf.reduce_mean( + cross_loss = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( logits=logits, labels=labels), - name=tf.GraphKeys.LOSSES) + name="cross_entropy_loss") + + tf.summary.scalar('cross_entropy_loss', cross_loss) + tf.add_to_collection(tf.GraphKeys.LOSSES, cross_loss) + if add_regularization_losses: regularization_losses = tf.get_collection( tf.GraphKeys.REGULARIZATION_LOSSES) - return tf.add_n( - [loss] + regularization_losses, name=tf.GraphKeys.LOSSES) + + total_loss = tf.add_n([cross_loss] + regularization_losses, name="total_loss") + return total_loss else: - return loss + return cross_loss def mean_cross_entropy_center_loss(logits, @@ -58,12 +62,12 @@ def mean_cross_entropy_center_loss(logits, """ # Cross entropy with tf.variable_scope('cross_entropy_loss'): - loss = tf.reduce_mean( + cross_loss = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( logits=logits, labels=labels), - name=tf.GraphKeys.LOSSES) - - tf.summary.scalar('cross_entropy_loss', loss) + name="cross_entropy_loss") + tf.add_to_collection(tf.GraphKeys.LOSSES, cross_loss) + tf.summary.scalar('cross_entropy_loss', cross_loss) # Appending center loss with tf.variable_scope('center_loss'): @@ -89,7 +93,8 @@ def mean_cross_entropy_center_loss(logits, regularization_losses = tf.get_collection( tf.GraphKeys.REGULARIZATION_LOSSES) total_loss = tf.add_n( - [loss] + regularization_losses, name=tf.GraphKeys.LOSSES) + [cross_loss] + regularization_losses, name="total_loss") + tf.add_to_collection(tf.GraphKeys.LOSSES, total_loss) tf.summary.scalar('total_loss', total_loss) loss = dict() diff --git a/bob/learn/tensorflow/network/InceptionResnetV1.py b/bob/learn/tensorflow/network/InceptionResnetV1.py index 3ff4259..f65c323 100644 --- a/bob/learn/tensorflow/network/InceptionResnetV1.py +++ b/bob/learn/tensorflow/network/InceptionResnetV1.py @@ -24,7 +24,7 @@ from __future__ import print_function import tensorflow as tf import tensorflow.contrib.slim as slim - +from .utils import is_trainable # Inception-Renset-A def block35(net, @@ -254,13 +254,46 @@ def reduction_b(net, reuse=None, trainable_variables=True): 3) return net +def inception_resnet_v1_batch_norm(inputs, + dropout_keep_prob=0.8, + bottleneck_layer_size=128, + reuse=None, + scope='InceptionResnetV1', + mode=tf.estimator.ModeKeys.TRAIN, + trainable_variables=None, + weight_decay=1e-5, + **kwargs): + """ + Creates the Inception Resnet V1 model applying batch not to each + Convolutional and FullyConnected layer. + + Parameters + ---------- + + inputs: a 4-D tensor of size [batch_size, height, width, 3]. + + num_classes: number of predicted classes. + + is_training: whether is training or not. + + dropout_keep_prob: float, the fraction to keep before final layer. + + reuse: whether or not the network and its variables should be reused. To be + able to reuse 'scope' must be given. + + scope: Optional variable_scope. + + trainable_variables: list + List of variables to be trainable=True + + Returns + ------- + logits: the logits outputs of the model. + end_points: the set of end_points from the inception model. + + """ + -def inference(images, - keep_probability, - phase_train=True, - bottleneck_layer_size=128, - weight_decay=0.0, - reuse=None): batch_norm_params = { # Decay for the moving averages. 'decay': 0.995, @@ -271,19 +304,20 @@ def inference(images, # Moving averages ends up in the trainable variables collection 'variables_collections': [tf.GraphKeys.TRAINABLE_VARIABLES], } - + with slim.arg_scope( [slim.conv2d, slim.fully_connected], weights_initializer=tf.truncated_normal_initializer(stddev=0.1), weights_regularizer=slim.l2_regularizer(weight_decay), normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params): - return inception_resnet_v1( - images, - is_training=phase_train, - dropout_keep_prob=keep_probability, - bottleneck_layer_size=bottleneck_layer_size, - reuse=reuse) + return inception_resnet_v1(inputs, + dropout_keep_prob=0.8, + bottleneck_layer_size=128, + reuse=None, + scope='InceptionResnetV1', + mode=mode, + trainable_variables=None,) def inception_resnet_v1(inputs, @@ -292,25 +326,35 @@ def inception_resnet_v1(inputs, reuse=None, scope='InceptionResnetV1', mode=tf.estimator.ModeKeys.TRAIN, - trainable_variables=True, - **kwargs): + trainable_variables=None, + **kwargs): """ Creates the Inception Resnet V1 model. - **Parameters** - - inputs: - a 4-D tensor of size [batch_size, height, width, 3]. - num_classes: - number of predicted classes. - mode: - whether is training or not. - dropout_keep_prob: - the fraction to keep before final layer. - reuse: - whether or not the network and its variables should be reused. To be able to reuse 'scope' must be given. - scope: - Optional variable_scope. + Parameters + ---------- + + inputs: a 4-D tensor of size [batch_size, height, width, 3]. + + num_classes: number of predicted classes. + + is_training: whether is training or not. + + dropout_keep_prob: float, the fraction to keep before final layer. + + reuse: whether or not the network and its variables should be reused. To be + able to reuse 'scope' must be given. + + scope: Optional variable_scope. + + trainable_variables: list + List of variables to be trainable=True + + Returns + ------- + logits: the logits outputs of the model. + end_points: the set of end_points from the inception model. + """ end_points = {} @@ -318,127 +362,164 @@ def inception_resnet_v1(inputs, with slim.arg_scope( [slim.batch_norm, slim.dropout], is_training=(mode == tf.estimator.ModeKeys.TRAIN)): + + with slim.arg_scope( [slim.conv2d, slim.max_pool2d, slim.avg_pool2d], stride=1, padding='SAME'): # 149 x 149 x 32 + name = "Conv2d_1a_3x3" + trainable = is_trainable(name, trainable_variables) net = slim.conv2d( inputs, 32, 3, stride=2, padding='VALID', - scope='Conv2d_1a_3x3', - trainable=trainable_variables) - end_points['Conv2d_1a_3x3'] = net + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net + # 147 x 147 x 32 + name = "Conv2d_2a_3x3" + trainable = is_trainable(name, trainable_variables) net = slim.conv2d( net, 32, 3, padding='VALID', - scope='Conv2d_2a_3x3', - trainable=trainable_variables) - end_points['Conv2d_2a_3x3'] = net + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net + # 147 x 147 x 64 + name = "Conv2d_2b_3x3" + trainable = is_trainable(name, trainable_variables) net = slim.conv2d( net, 64, 3, - scope='Conv2d_2b_3x3', - trainable=trainable_variables) - end_points['Conv2d_2b_3x3'] = net + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net # 73 x 73 x 64 net = slim.max_pool2d( net, 3, stride=2, padding='VALID', scope='MaxPool_3a_3x3') end_points['MaxPool_3a_3x3'] = net + # 73 x 73 x 80 + name = "Conv2d_3b_1x1" + trainable = is_trainable(name, trainable_variables) net = slim.conv2d( net, 80, 1, padding='VALID', - scope='Conv2d_3b_1x1', - trainable=trainable_variables) - end_points['Conv2d_3b_1x1'] = net + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net + # 71 x 71 x 192 + name = "Conv2d_4a_3x3" + trainable = is_trainable(name, trainable_variables) net = slim.conv2d( net, 192, 3, padding='VALID', - scope='Conv2d_4a_3x3', - trainable=trainable_variables) - end_points['Conv2d_4a_3x3'] = net + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net + # 35 x 35 x 256 + name = "Conv2d_4b_3x3" + trainable = is_trainable(name, trainable_variables) net = slim.conv2d( net, 256, 3, stride=2, padding='VALID', - scope='Conv2d_4b_3x3', - trainable=trainable_variables) - end_points['Conv2d_4b_3x3'] = net + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net # 5 x Inception-resnet-A + name = "block35" + trainable = is_trainable(name, trainable_variables) net = slim.repeat( net, 5, block35, scale=0.17, - trainable_variables=trainable_variables, + trainable_variables=trainable, reuse=reuse) - end_points['Mixed_5a'] = net + end_points[name] = net # Reduction-A - with tf.variable_scope('Mixed_6a'): + name = "Mixed_6a" + trainable = is_trainable(name, trainable_variables) + with tf.variable_scope(name): net = reduction_a( net, 192, 192, 256, 384, - trainable_variables=trainable_variables, + trainable_variables=trainable, reuse=reuse) - end_points['Mixed_6a'] = net + end_points[name] = net # 10 x Inception-Resnet-B + name = "block17" + trainable = is_trainable(name, trainable_variables) net = slim.repeat( net, 10, block17, scale=0.10, - trainable_variables=trainable_variables, + trainable_variables=trainable, reuse=reuse) - end_points['Mixed_6b'] = net + end_points[name] = net # Reduction-B - with tf.variable_scope('Mixed_7a'): + name = "Mixed_7a" + trainable = is_trainable(name, trainable_variables) + with tf.variable_scope(name): net = reduction_b( net, - trainable_variables=trainable_variables, + trainable_variables=trainable, reuse=reuse) - end_points['Mixed_7a'] = net + end_points[name] = net # 5 x Inception-Resnet-C + name = "block8" + trainable = is_trainable(name, trainable_variables) net = slim.repeat( net, 5, block8, scale=0.20, - trainable_variables=trainable_variables, + trainable_variables=trainable, reuse=reuse) - end_points['Mixed_8a'] = net + end_points[name] = net + name = "Mixed_8b" + trainable = is_trainable(name, trainable_variables) net = block8( net, activation_fn=None, - trainable_variables=trainable_variables, + trainable_variables=trainable, reuse=reuse) - end_points['Mixed_8b'] = net + end_points[name] = net with tf.variable_scope('Logits'): end_points['PrePool'] = net @@ -458,12 +539,14 @@ def inception_resnet_v1(inputs, end_points['PreLogitsFlatten'] = net + name = "Bottleneck" + trainable = is_trainable(name, trainable_variables) net = slim.fully_connected( net, bottleneck_layer_size, activation_fn=None, - scope='Bottleneck', + scope=name, reuse=reuse, - trainable=trainable_variables) + trainable=trainable) return net, end_points diff --git a/bob/learn/tensorflow/network/InceptionResnetV2.py b/bob/learn/tensorflow/network/InceptionResnetV2.py index aa5ea67..a611a6a 100644 --- a/bob/learn/tensorflow/network/InceptionResnetV2.py +++ b/bob/learn/tensorflow/network/InceptionResnetV2.py @@ -203,11 +203,45 @@ def block8(net, return net -def inference(images, - keep_probability, - bottleneck_layer_size=128, - weight_decay=0.0, - reuse=None): +def inception_resnet_v2_batch_norm(inputs, + dropout_keep_prob=0.8, + bottleneck_layer_size=128, + reuse=None, + scope='InceptionResnetV2', + mode=tf.estimator.ModeKeys.TRAIN, + trainable_variables=None, + weight_decay = 5e-5, + **kwargs): + """ + Creates the Inception Resnet V2 model applying batch not to each + Convolutional and FullyConnected layer. + + + **Parameters**: + + inputs: a 4-D tensor of size [batch_size, height, width, 3]. + + num_classes: number of predicted classes. + + is_training: whether is training or not. + + dropout_keep_prob: float, the fraction to keep before final layer. + + reuse: whether or not the network and its variables should be reused. To be + able to reuse 'scope' must be given. + + scope: Optional variable_scope. + + trainable_variables: list + List of variables to be trainable=True + + **Returns**: + + logits: the logits outputs of the model. + end_points: the set of end_points from the inception model. + """ + + batch_norm_params = { # Decay for the moving averages. 'decay': 0.995, @@ -218,6 +252,7 @@ def inference(images, # Moving averages ends up in the trainable variables collection 'variables_collections': [tf.GraphKeys.TRAINABLE_VARIABLES], } + with slim.arg_scope( [slim.conv2d, slim.fully_connected], weights_initializer=tf.truncated_normal_initializer(stddev=0.1), @@ -225,11 +260,13 @@ def inference(images, normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params): return inception_resnet_v2( - images, - mode=tf.estimator.ModeKeys.PREDICT, - dropout_keep_prob=keep_probability, + inputs, + dropout_keep_prob=dropout_keep_prob, bottleneck_layer_size=bottleneck_layer_size, - reuse=reuse) + reuse=reuse, + scope=scope, + mode=mode, + trainable_variables=trainable_variables) def inception_resnet_v2(inputs, @@ -242,7 +279,8 @@ def inception_resnet_v2(inputs, **kwargs): """Creates the Inception Resnet V2 model. - **Parameters**: + Parameters + ---------- inputs: a 4-D tensor of size [batch_size, height, width, 3]. @@ -260,347 +298,367 @@ def inception_resnet_v2(inputs, trainable_variables: list List of variables to be trainable=True - **Returns**: - + Returns + ------- logits: the logits outputs of the model. end_points: the set of end_points from the inception model. """ end_points = {} + + batch_norm_params = { + # Decay for the moving averages. + 'decay': 0.995, + # epsilon to prevent 0s in variance. + 'epsilon': 0.001, + # force in-place updates of mean and variance estimates + 'updates_collections': None, + # Moving averages ends up in the trainable variables collection + 'variables_collections': [tf.GraphKeys.TRAINABLE_VARIABLES], + } + with tf.variable_scope(scope, 'InceptionResnetV2', [inputs], reuse=reuse): with slim.arg_scope( [slim.batch_norm, slim.dropout], is_training=(mode == tf.estimator.ModeKeys.TRAIN)): - with slim.arg_scope( - [slim.conv2d, slim.max_pool2d, slim.avg_pool2d], - stride=1, - padding='SAME'): - # 149 x 149 x 32 - name = "Conv2d_1a_3x3" - trainable = is_trainable(name, trainable_variables) - net = slim.conv2d( - inputs, - 32, - 3, - stride=2, - padding='VALID', - scope=name, - trainable=trainable, - reuse=reuse) - end_points[name] = net - - # 147 x 147 x 32 - name = "Conv2d_2a_3x3" - trainable = is_trainable(name, trainable_variables) - net = slim.conv2d( - net, - 32, - 3, - padding='VALID', - scope=name, - trainable=trainable, - reuse=reuse) - end_points[name] = net - - # 147 x 147 x 64 - name = "Conv2d_2b_3x3" - trainable = is_trainable(name, trainable_variables) - net = slim.conv2d( - net, 64, 3, scope=name, trainable=trainable, reuse=reuse) - end_points[name] = net - - # 73 x 73 x 64 - net = slim.max_pool2d( - net, 3, stride=2, padding='VALID', scope='MaxPool_3a_3x3') - end_points['MaxPool_3a_3x3'] = net - - # 73 x 73 x 80 - name = "Conv2d_3b_1x1" - trainable = is_trainable(name, trainable_variables) - net = slim.conv2d( - net, - 80, - 1, - padding='VALID', - scope=name, - trainable=trainable, - reuse=reuse) - end_points[name] = net - - # 71 x 71 x 192 - name = "Conv2d_4a_3x3" - trainable = is_trainable(name, trainable_variables) - net = slim.conv2d( - net, - 192, - 3, - padding='VALID', - scope=name, - trainable=trainable, - reuse=reuse) - end_points[name] = net - - # 35 x 35 x 192 - net = slim.max_pool2d( - net, 3, stride=2, padding='VALID', scope='MaxPool_5a_3x3') - end_points['MaxPool_5a_3x3'] = net - - # 35 x 35 x 320 - name = "Mixed_5b" - trainable = is_trainable(name, trainable_variables) - with tf.variable_scope(name): - with tf.variable_scope('Branch_0'): - tower_conv = slim.conv2d( - net, - 96, - 1, - scope='Conv2d_1x1', - trainable=trainable, - reuse=reuse) - with tf.variable_scope('Branch_1'): - tower_conv1_0 = slim.conv2d( - net, - 48, - 1, - scope='Conv2d_0a_1x1', - trainable=trainable, - reuse=reuse) - tower_conv1_1 = slim.conv2d( - tower_conv1_0, - 64, - 5, - scope='Conv2d_0b_5x5', - trainable=trainable, - reuse=reuse) - with tf.variable_scope('Branch_2'): - tower_conv2_0 = slim.conv2d( - net, - 64, - 1, - scope='Conv2d_0a_1x1', - trainable=trainable, - reuse=reuse) - tower_conv2_1 = slim.conv2d( - tower_conv2_0, - 96, - 3, - scope='Conv2d_0b_3x3', - trainable=trainable, - reuse=reuse) - tower_conv2_2 = slim.conv2d( - tower_conv2_1, - 96, - 3, - scope='Conv2d_0c_3x3', - trainable=trainable, - reuse=reuse) - with tf.variable_scope('Branch_3'): - tower_pool = slim.avg_pool2d( - net, - 3, - stride=1, - padding='SAME', - scope='AvgPool_0a_3x3') - tower_pool_1 = slim.conv2d( - tower_pool, - 64, - 1, - scope='Conv2d_0b_1x1', - trainable=trainable, - reuse=reuse) - net = tf.concat([ - tower_conv, tower_conv1_1, tower_conv2_2, tower_pool_1 - ], 3) - end_points[name] = net - - # BLOCK 35 - name = "Block35" - trainable = is_trainable(name, trainable_variables) - net = slim.repeat( - net, - 10, - block35, - scale=0.17, - trainable_variables=trainable, - reuse=reuse) - - # 17 x 17 x 1024 - name = "Mixed_6a" - trainable = is_trainable(name, trainable_variables) - with tf.variable_scope(name): - with tf.variable_scope('Branch_0'): - tower_conv = slim.conv2d( - net, - 384, - 3, - stride=2, - padding='VALID', - scope='Conv2d_1a_3x3', - trainable=trainable, - reuse=reuse) - with tf.variable_scope('Branch_1'): - tower_conv1_0 = slim.conv2d( - net, - 256, - 1, - scope='Conv2d_0a_1x1', - trainable=trainable, - reuse=reuse) - tower_conv1_1 = slim.conv2d( - tower_conv1_0, - 256, - 3, - scope='Conv2d_0b_3x3', - trainable=trainable, - reuse=reuse) - tower_conv1_2 = slim.conv2d( - tower_conv1_1, - 384, - 3, - stride=2, - padding='VALID', - scope='Conv2d_1a_3x3', - trainable=trainable, - reuse=reuse) - with tf.variable_scope('Branch_2'): - tower_pool = slim.max_pool2d( - net, - 3, - stride=2, - padding='VALID', - scope='MaxPool_1a_3x3') - net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3) - - end_points[name] = net - - # BLOCK 17 - name = "Block17" - trainable = is_trainable(name, trainable_variables) - net = slim.repeat( - net, - 20, - block17, - scale=0.10, - trainable_variables=trainable, - reuse=reuse) - - name = "Mixed_7a" - trainable = is_trainable(name, trainable_variables) - with tf.variable_scope(name): - with tf.variable_scope('Branch_0'): - tower_conv = slim.conv2d( - net, - 256, - 1, - scope='Conv2d_0a_1x1', - trainable=trainable, - reuse=reuse) - tower_conv_1 = slim.conv2d( - tower_conv, - 384, - 3, - stride=2, - padding='VALID', - scope='Conv2d_1a_3x3', - trainable=trainable, - reuse=reuse) - with tf.variable_scope('Branch_1'): - tower_conv1 = slim.conv2d( - net, - 256, - 1, - scope='Conv2d_0a_1x1', - trainable=trainable, - reuse=reuse) - tower_conv1_1 = slim.conv2d( - tower_conv1, - 288, - 3, - stride=2, - padding='VALID', - scope='Conv2d_1a_3x3', - trainable=trainable, - reuse=reuse) - with tf.variable_scope('Branch_2'): - tower_conv2 = slim.conv2d( - net, - 256, - 1, - scope='Conv2d_0a_1x1', - trainable=trainable, - reuse=reuse) - tower_conv2_1 = slim.conv2d( - tower_conv2, - 288, - 3, - scope='Conv2d_0b_3x3', - trainable=trainable, - reuse=reuse) - tower_conv2_2 = slim.conv2d( - tower_conv2_1, - 320, - 3, - stride=2, - padding='VALID', - scope='Conv2d_1a_3x3', - trainable=trainable, - reuse=reuse) - with tf.variable_scope('Branch_3'): - tower_pool = slim.max_pool2d( + + #with slim.arg_scope( + # [slim.conv2d, slim.fully_connected], + # weights_initializer=tf.truncated_normal_initializer(stddev=0.1), + # weights_regularizer=slim.l2_regularizer(weight_decay), + # normalizer_fn=slim.batch_norm, + # normalizer_params=batch_norm_params): + + with slim.arg_scope( + [slim.conv2d, slim.max_pool2d, slim.avg_pool2d], + stride=1, + padding='SAME'): + # 149 x 149 x 32 + name = "Conv2d_1a_3x3" + trainable = is_trainable(name, trainable_variables) + net = slim.conv2d( + inputs, + 32, + 3, + stride=2, + padding='VALID', + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net + + # 147 x 147 x 32 + name = "Conv2d_2a_3x3" + trainable = is_trainable(name, trainable_variables) + net = slim.conv2d( + net, + 32, + 3, + padding='VALID', + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net + + # 147 x 147 x 64 + name = "Conv2d_2b_3x3" + trainable = is_trainable(name, trainable_variables) + net = slim.conv2d( + net, 64, 3, scope=name, trainable=trainable, reuse=reuse) + end_points[name] = net + + # 73 x 73 x 64 + net = slim.max_pool2d( + net, 3, stride=2, padding='VALID', scope='MaxPool_3a_3x3') + end_points['MaxPool_3a_3x3'] = net + + # 73 x 73 x 80 + name = "Conv2d_3b_1x1" + trainable = is_trainable(name, trainable_variables) + net = slim.conv2d( + net, + 80, + 1, + padding='VALID', + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net + + # 71 x 71 x 192 + name = "Conv2d_4a_3x3" + trainable = is_trainable(name, trainable_variables) + net = slim.conv2d( + net, + 192, + 3, + padding='VALID', + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net + + # 35 x 35 x 192 + net = slim.max_pool2d( + net, 3, stride=2, padding='VALID', scope='MaxPool_5a_3x3') + end_points['MaxPool_5a_3x3'] = net + + # 35 x 35 x 320 + name = "Mixed_5b" + trainable = is_trainable(name, trainable_variables) + with tf.variable_scope(name): + with tf.variable_scope('Branch_0'): + tower_conv = slim.conv2d( + net, + 96, + 1, + scope='Conv2d_1x1', + trainable=trainable, + reuse=reuse) + with tf.variable_scope('Branch_1'): + tower_conv1_0 = slim.conv2d( + net, + 48, + 1, + scope='Conv2d_0a_1x1', + trainable=trainable, + reuse=reuse) + tower_conv1_1 = slim.conv2d( + tower_conv1_0, + 64, + 5, + scope='Conv2d_0b_5x5', + trainable=trainable, + reuse=reuse) + with tf.variable_scope('Branch_2'): + tower_conv2_0 = slim.conv2d( + net, + 64, + 1, + scope='Conv2d_0a_1x1', + trainable=trainable, + reuse=reuse) + tower_conv2_1 = slim.conv2d( + tower_conv2_0, + 96, + 3, + scope='Conv2d_0b_3x3', + trainable=trainable, + reuse=reuse) + tower_conv2_2 = slim.conv2d( + tower_conv2_1, + 96, + 3, + scope='Conv2d_0c_3x3', + trainable=trainable, + reuse=reuse) + with tf.variable_scope('Branch_3'): + tower_pool = slim.avg_pool2d( + net, + 3, + stride=1, + padding='SAME', + scope='AvgPool_0a_3x3') + tower_pool_1 = slim.conv2d( + tower_pool, + 64, + 1, + scope='Conv2d_0b_1x1', + trainable=trainable, + reuse=reuse) + net = tf.concat([ + tower_conv, tower_conv1_1, tower_conv2_2, tower_pool_1 + ], 3) + end_points[name] = net + + # BLOCK 35 + name = "Block35" + trainable = is_trainable(name, trainable_variables) + net = slim.repeat( + net, + 10, + block35, + scale=0.17, + trainable_variables=trainable, + reuse=reuse) + + # 17 x 17 x 1024 + name = "Mixed_6a" + trainable = is_trainable(name, trainable_variables) + with tf.variable_scope(name): + with tf.variable_scope('Branch_0'): + tower_conv = slim.conv2d( + net, + 384, + 3, + stride=2, + padding='VALID', + scope='Conv2d_1a_3x3', + trainable=trainable, + reuse=reuse) + with tf.variable_scope('Branch_1'): + tower_conv1_0 = slim.conv2d( + net, + 256, + 1, + scope='Conv2d_0a_1x1', + trainable=trainable, + reuse=reuse) + tower_conv1_1 = slim.conv2d( + tower_conv1_0, + 256, + 3, + scope='Conv2d_0b_3x3', + trainable=trainable, + reuse=reuse) + tower_conv1_2 = slim.conv2d( + tower_conv1_1, + 384, + 3, + stride=2, + padding='VALID', + scope='Conv2d_1a_3x3', + trainable=trainable, + reuse=reuse) + with tf.variable_scope('Branch_2'): + tower_pool = slim.max_pool2d( + net, + 3, + stride=2, + padding='VALID', + scope='MaxPool_1a_3x3') + net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3) + + end_points[name] = net + + # BLOCK 17 + name = "Block17" + trainable = is_trainable(name, trainable_variables) + net = slim.repeat( + net, + 20, + block17, + scale=0.10, + trainable_variables=trainable, + reuse=reuse) + + name = "Mixed_7a" + trainable = is_trainable(name, trainable_variables) + with tf.variable_scope(name): + with tf.variable_scope('Branch_0'): + tower_conv = slim.conv2d( + net, + 256, + 1, + scope='Conv2d_0a_1x1', + trainable=trainable, + reuse=reuse) + tower_conv_1 = slim.conv2d( + tower_conv, + 384, + 3, + stride=2, + padding='VALID', + scope='Conv2d_1a_3x3', + trainable=trainable, + reuse=reuse) + with tf.variable_scope('Branch_1'): + tower_conv1 = slim.conv2d( + net, + 256, + 1, + scope='Conv2d_0a_1x1', + trainable=trainable, + reuse=reuse) + tower_conv1_1 = slim.conv2d( + tower_conv1, + 288, + 3, + stride=2, + padding='VALID', + scope='Conv2d_1a_3x3', + trainable=trainable, + reuse=reuse) + with tf.variable_scope('Branch_2'): + tower_conv2 = slim.conv2d( + net, + 256, + 1, + scope='Conv2d_0a_1x1', + trainable=trainable, + reuse=reuse) + tower_conv2_1 = slim.conv2d( + tower_conv2, + 288, + 3, + scope='Conv2d_0b_3x3', + trainable=trainable, + reuse=reuse) + tower_conv2_2 = slim.conv2d( + tower_conv2_1, + 320, + 3, + stride=2, + padding='VALID', + scope='Conv2d_1a_3x3', + trainable=trainable, + reuse=reuse) + with tf.variable_scope('Branch_3'): + tower_pool = slim.max_pool2d( + net, + 3, + stride=2, + padding='VALID', + scope='MaxPool_1a_3x3') + net = tf.concat([ + tower_conv_1, tower_conv1_1, tower_conv2_2, tower_pool + ], 3) + end_points[name] = net + + # Block 8 + name = "Block8" + trainable = is_trainable(name, trainable_variables) + net = slim.repeat( + net, + 9, + block8, + scale=0.20, + trainable_variables=trainable, + reuse=reuse) + net = block8( + net, + activation_fn=None, + trainable_variables=trainable, + reuse=reuse) + + name = "Conv2d_7b_1x1" + trainable = is_trainable(name, trainable_variables) + net = slim.conv2d( + net, 1536, 1, scope=name, trainable=trainable, reuse=reuse) + end_points[name] = net + + with tf.variable_scope('Logits'): + end_points['PrePool'] = net + # pylint: disable=no-member + net = slim.avg_pool2d( net, - 3, - stride=2, + net.get_shape()[1:3], padding='VALID', - scope='MaxPool_1a_3x3') - net = tf.concat([ - tower_conv_1, tower_conv1_1, tower_conv2_2, tower_pool - ], 3) - end_points[name] = net - - # Block 8 - name = "Block8" - trainable = is_trainable(name, trainable_variables) - net = slim.repeat( - net, - 9, - block8, - scale=0.20, - trainable_variables=trainable, - reuse=reuse) - net = block8( - net, - activation_fn=None, - trainable_variables=trainable, - reuse=reuse) - - name = "Conv2d_7b_1x1" - trainable = is_trainable(name, trainable_variables) - net = slim.conv2d( - net, 1536, 1, scope=name, trainable=trainable, reuse=reuse) - end_points[name] = net - - with tf.variable_scope('Logits'): - end_points['PrePool'] = net - # pylint: disable=no-member - net = slim.avg_pool2d( + scope='AvgPool_1a_8x8') + net = slim.flatten(net) + + net = slim.dropout(net, dropout_keep_prob, scope='Dropout') + + end_points['PreLogitsFlatten'] = net + + name = "Bottleneck" + trainable = is_trainable(name, trainable_variables) + net = slim.fully_connected( net, - net.get_shape()[1:3], - padding='VALID', - scope='AvgPool_1a_8x8') - net = slim.flatten(net) - - net = slim.dropout(net, dropout_keep_prob, scope='Dropout') - - end_points['PreLogitsFlatten'] = net - - name = "Bottleneck" - trainable = is_trainable(name, trainable_variables) - net = slim.fully_connected( - net, - bottleneck_layer_size, - activation_fn=None, - scope=name, - reuse=reuse, - trainable=trainable) - end_points[name] = net + bottleneck_layer_size, + activation_fn=None, + scope=name, + reuse=reuse, + trainable=trainable) + end_points[name] = net return net, end_points diff --git a/bob/learn/tensorflow/network/__init__.py b/bob/learn/tensorflow/network/__init__.py index 7d46560..5ea2dc4 100644 --- a/bob/learn/tensorflow/network/__init__.py +++ b/bob/learn/tensorflow/network/__init__.py @@ -3,8 +3,9 @@ from .LightCNN9 import light_cnn9 from .Dummy import dummy from .MLP import mlp from .Embedding import Embedding -from .InceptionResnetV2 import inception_resnet_v2 -from .InceptionResnetV1 import inception_resnet_v1 +from .InceptionResnetV2 import inception_resnet_v2, inception_resnet_v2_batch_norm +from .InceptionResnetV1 import inception_resnet_v1, inception_resnet_v1_batch_norm +from .vgg import vgg_16 from . import SimpleCNN diff --git a/bob/learn/tensorflow/network/utils.py b/bob/learn/tensorflow/network/utils.py index 5b32ed1..852db26 100644 --- a/bob/learn/tensorflow/network/utils.py +++ b/bob/learn/tensorflow/network/utils.py @@ -9,7 +9,7 @@ import tensorflow.contrib.slim as slim def append_logits(graph, n_classes, reuse=False, - l2_regularizer=0.001, + l2_regularizer=5e-05, weights_std=0.1): return slim.fully_connected( graph, diff --git a/bob/learn/tensorflow/network/vgg.py b/bob/learn/tensorflow/network/vgg.py new file mode 100644 index 0000000..f3453de --- /dev/null +++ b/bob/learn/tensorflow/network/vgg.py @@ -0,0 +1,300 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Contains model definitions for versions of the Oxford VGG network. + +These model definitions were introduced in the following technical report: + + Very Deep Convolutional Networks For Large-Scale Image Recognition + Karen Simonyan and Andrew Zisserman + arXiv technical report, 2015 + PDF: http://arxiv.org/pdf/1409.1556.pdf + ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf + CC-BY-4.0 + +More information can be obtained from the VGG website: +www.robots.ox.ac.uk/~vgg/research/very_deep/ + +Usage: + with slim.arg_scope(vgg.vgg_arg_scope()): + outputs, end_points = vgg.vgg_a(inputs) + + with slim.arg_scope(vgg.vgg_arg_scope()): + outputs, end_points = vgg.vgg_16(inputs) + +@@vgg_a +@@vgg_16 +@@vgg_19 +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import layers +from tensorflow.contrib.framework.python.ops import arg_scope +from tensorflow.contrib.layers.python.layers import layers as layers_lib +from tensorflow.contrib.layers.python.layers import regularizers +from tensorflow.contrib.layers.python.layers import utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import variable_scope +import tensorflow.contrib.slim as slim +import tensorflow as tf + + +def vgg_arg_scope(weight_decay=0.0005): + """Defines the VGG arg scope. + + Args: + weight_decay: The l2 regularization coefficient. + + Returns: + An arg_scope. + """ + with arg_scope( + [layers.conv2d, layers_lib.fully_connected], + activation_fn=nn_ops.relu, + weights_regularizer=regularizers.l2_regularizer(weight_decay), + biases_initializer=init_ops.zeros_initializer()): + with arg_scope([layers.conv2d], padding='SAME') as arg_sc: + return arg_sc + + +def vgg_a(inputs, + num_classes=1000, + is_training=True, + dropout_keep_prob=0.5, + spatial_squeeze=True, + scope='vgg_a'): + """Oxford Net VGG 11-Layers version A Example. + + Note: All the fully_connected layers have been transformed to conv2d layers. + To use in classification mode, resize input to 224x224. + + Args: + inputs: a tensor of size [batch_size, height, width, channels]. + num_classes: number of predicted classes. + is_training: whether or not the model is being trained. + dropout_keep_prob: the probability that activations are kept in the dropout + layers during training. + spatial_squeeze: whether or not should squeeze the spatial dimensions of the + outputs. Useful to remove unnecessary dimensions for classification. + scope: Optional scope for the variables. + + Returns: + the last op containing the log predictions and end_points dict. + """ + with variable_scope.variable_scope(scope, 'vgg_a', [inputs]) as sc: + end_points_collection = sc.original_name_scope + '_end_points' + # Collect outputs for conv2d, fully_connected and max_pool2d. + with arg_scope( + [layers.conv2d, layers_lib.max_pool2d], + outputs_collections=end_points_collection): + net = layers_lib.repeat( + inputs, 1, layers.conv2d, 64, [3, 3], scope='conv1') + net = layers_lib.max_pool2d(net, [2, 2], scope='pool1') + net = layers_lib.repeat(net, 1, layers.conv2d, 128, [3, 3], scope='conv2') + net = layers_lib.max_pool2d(net, [2, 2], scope='pool2') + net = layers_lib.repeat(net, 2, layers.conv2d, 256, [3, 3], scope='conv3') + net = layers_lib.max_pool2d(net, [2, 2], scope='pool3') + net = layers_lib.repeat(net, 2, layers.conv2d, 512, [3, 3], scope='conv4') + net = layers_lib.max_pool2d(net, [2, 2], scope='pool4') + net = layers_lib.repeat(net, 2, layers.conv2d, 512, [3, 3], scope='conv5') + net = layers_lib.max_pool2d(net, [2, 2], scope='pool5') + # Use conv2d instead of fully_connected layers. + net = layers.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') + net = layers_lib.dropout( + net, dropout_keep_prob, is_training=is_training, scope='dropout6') + net = layers.conv2d(net, 4096, [1, 1], scope='fc7') + net = layers_lib.dropout( + net, dropout_keep_prob, is_training=is_training, scope='dropout7') + net = layers.conv2d( + net, + num_classes, [1, 1], + activation_fn=None, + normalizer_fn=None, + scope='fc8') + # Convert end_points_collection into a end_point dict. + end_points = utils.convert_collection_to_dict(end_points_collection) + if spatial_squeeze: + net = array_ops.squeeze(net, [1, 2], name='fc8/squeezed') + end_points[sc.name + '/fc8'] = net + return net, end_points + + +vgg_a.default_image_size = 224 + + +def vgg_16(inputs, + mode=tf.estimator.ModeKeys.TRAIN, + reuse=None, + dropout_keep_prob=0.5, + spatial_squeeze=True, + scope='vgg_16', + trainable_variables=None, + **kwargs): + """Oxford Net VGG 16-Layers version D Example. + + Note: All the fully_connected layers have been transformed to conv2d layers. + To use in classification mode, resize input to 224x224. + + Args: + inputs: a tensor of size [batch_size, height, width, channels]. + num_classes: number of predicted classes. + is_training: whether or not the model is being trained. + dropout_keep_prob: the probability that activations are kept in the dropout + layers during training. + spatial_squeeze: whether or not should squeeze the spatial dimensions of the + outputs. Useful to remove unnecessary dimensions for classification. + scope: Optional scope for the variables. + + Returns: + the last op containing the log predictions and end_points dict. + """ + + is_training = mode == tf.estimator.ModeKeys.TRAIN + end_points = dict() + with variable_scope.variable_scope(scope, 'vgg_16', [inputs]) as sc: + end_points_collection = sc.original_name_scope + '_end_points' + # Collect outputs for conv2d, fully_connected and max_pool2d. + with arg_scope( + [layers.conv2d, layers_lib.fully_connected, layers_lib.max_pool2d], + outputs_collections=end_points_collection): + + with slim.arg_scope(vgg_arg_scope()): + + name = "conv1" + net = layers_lib.repeat( + inputs, 2, layers.conv2d, 64, [3, 3], scope=name, reuse=reuse) + end_points[name] = net + net = layers_lib.max_pool2d(net, [2, 2], scope='pool1') + + name = "conv2" + net = layers_lib.repeat(net, 2, layers.conv2d, 128, [3, 3], scope=name, reuse=reuse) + end_points[name] = net + net = layers_lib.max_pool2d(net, [2, 2], scope='pool2') + + name = "conv3" + net = layers_lib.repeat(net, 3, layers.conv2d, 256, [3, 3], scope=name, reuse=reuse) + end_points[name] = net + net = layers_lib.max_pool2d(net, [2, 2], scope='pool3') + + name = "conv4" + net = layers_lib.repeat(net, 3, layers.conv2d, 512, [3, 3], scope=name, reuse=reuse) + end_points[name] = net + net = layers_lib.max_pool2d(net, [2, 2], scope='pool4') + + name = "conv5" + net = layers_lib.repeat(net, 3, layers.conv2d, 512, [3, 3], scope=name, reuse=reuse) + end_points[name] = net + net = layers_lib.max_pool2d(net, [2, 2], scope='pool5') + + # Tiago: Make things flat + net = layers_lib.flatten(net) + + # Use conv2d instead of fully_connected layers. + # net = layers.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6', reuse=reuse) + # net = layers.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6', reuse=reuse) + name = "fc6" + net = layers.fully_connected(net, 4096, activation_fn=tf.nn.relu, scope=name, reuse=reuse) + end_points[name] = net + + net = layers_lib.dropout( + net, dropout_keep_prob, is_training=is_training, scope='dropout6') + + # net = layers.conv2d(net, 4096, [1, 1], scope='fc7', reuse=reuse) + name = "fc7" + net = layers.fully_connected(net, 4096, activation_fn=tf.nn.relu, scope=name, reuse=reuse) + end_points[name] = net + net = layers_lib.dropout( + net, dropout_keep_prob, is_training=is_training, scope='dropout7') + + return net, end_points + + +vgg_16.default_image_size = 224 + + +def vgg_19(inputs, + num_classes=1000, + is_training=True, + dropout_keep_prob=0.5, + spatial_squeeze=True, + scope='vgg_19'): + """Oxford Net VGG 19-Layers version E Example. + + Note: All the fully_connected layers have been transformed to conv2d layers. + To use in classification mode, resize input to 224x224. + + Args: + inputs: a tensor of size [batch_size, height, width, channels]. + num_classes: number of predicted classes. + is_training: whether or not the model is being trained. + dropout_keep_prob: the probability that activations are kept in the dropout + layers during training. + spatial_squeeze: whether or not should squeeze the spatial dimensions of the + outputs. Useful to remove unnecessary dimensions for classification. + scope: Optional scope for the variables. + + Returns: + the last op containing the log predictions and end_points dict. + """ + with variable_scope.variable_scope(scope, 'vgg_19', [inputs]) as sc: + end_points_collection = sc.name + '_end_points' + # Collect outputs for conv2d, fully_connected and max_pool2d. + with arg_scope( + [layers.conv2d, layers_lib.fully_connected, layers_lib.max_pool2d], + outputs_collections=end_points_collection): + + with slim.arg_scope(vgg_arg_scope()): + + net = layers_lib.repeat( + inputs, 2, layers.conv2d, 64, [3, 3], scope='conv1') + net = layers_lib.max_pool2d(net, [2, 2], scope='pool1') + net = layers_lib.repeat(net, 2, layers.conv2d, 128, [3, 3], scope='conv2') + net = layers_lib.max_pool2d(net, [2, 2], scope='pool2') + net = layers_lib.repeat(net, 4, layers.conv2d, 256, [3, 3], scope='conv3') + net = layers_lib.max_pool2d(net, [2, 2], scope='pool3') + net = layers_lib.repeat(net, 4, layers.conv2d, 512, [3, 3], scope='conv4') + net = layers_lib.max_pool2d(net, [2, 2], scope='pool4') + net = layers_lib.repeat(net, 4, layers.conv2d, 512, [3, 3], scope='conv5') + net = layers_lib.max_pool2d(net, [2, 2], scope='pool5') + # Use conv2d instead of fully_connected layers. + net = layers.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') + net = layers_lib.dropout( + net, dropout_keep_prob, is_training=is_training, scope='dropout6') + net = layers.conv2d(net, 4096, [1, 1], scope='fc7') + net = layers_lib.dropout( + net, dropout_keep_prob, is_training=is_training, scope='dropout7') + net = layers.conv2d( + net, + num_classes, [1, 1], + activation_fn=None, + normalizer_fn=None, + scope='fc8') + # Convert end_points_collection into a end_point dict. + end_points = utils.convert_collection_to_dict(end_points_collection) + if spatial_squeeze: + net = array_ops.squeeze(net, [1, 2], name='fc8/squeezed') + end_points[sc.name + '/fc8'] = net + return net, end_points + + +vgg_19.default_image_size = 224 + +# Alias +vgg_d = vgg_16 +vgg_e = vgg_19 diff --git a/bob/learn/tensorflow/test/test_architectures.py b/bob/learn/tensorflow/test/test_architectures.py new file mode 100644 index 0000000..4e6c532 --- /dev/null +++ b/bob/learn/tensorflow/test/test_architectures.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : +# @author: Tiago de Freitas Pereira + +import tensorflow as tf +from bob.learn.tensorflow.network import inception_resnet_v2, inception_resnet_v2_batch_norm,\ + inception_resnet_v1, inception_resnet_v1_batch_norm + +def test_inceptionv2(): + + # Testing WITHOUT batch norm + inputs = tf.placeholder(tf.float32, shape=(1, 160, 160, 1)) + graph, _ = inception_resnet_v2(inputs) + assert len(tf.trainable_variables())==490 + + tf.reset_default_graph() + assert len(tf.global_variables()) == 0 + + # Testing WITH batch norm + inputs = tf.placeholder(tf.float32, shape=(1, 160, 160, 1)) + graph, _ = inception_resnet_v2_batch_norm(inputs) + assert len(tf.trainable_variables())==900 + + tf.reset_default_graph() + assert len(tf.global_variables()) == 0 + + +def test_inceptionv1(): + + # Testing WITHOUT batch norm + inputs = tf.placeholder(tf.float32, shape=(1, 160, 160, 1)) + graph, _ = inception_resnet_v1(inputs) + assert len(tf.trainable_variables())==266 + + tf.reset_default_graph() + assert len(tf.global_variables()) == 0 + + # Testing WITH batch norm + inputs = tf.placeholder(tf.float32, shape=(1, 160, 160, 1)) + graph, _ = inception_resnet_v1_batch_norm(inputs) + assert len(tf.trainable_variables())==490 + + tf.reset_default_graph() + assert len(tf.global_variables()) == 0 + -- GitLab From 9013c0b8670cc67dfa0795ca5495c8f7aae41ea8 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira Date: Wed, 24 Jan 2018 10:41:41 +0100 Subject: [PATCH 2/2] Removed VGG16 Documented exponential moving average Removed comments --- bob/learn/tensorflow/estimators/Logits.py | 13 + bob/learn/tensorflow/loss/BaseLoss.py | 1 - .../tensorflow/network/InceptionResnetV2.py | 7 - bob/learn/tensorflow/network/__init__.py | 1 - bob/learn/tensorflow/network/vgg.py | 300 ------------------ 5 files changed, 13 insertions(+), 309 deletions(-) delete mode 100644 bob/learn/tensorflow/network/vgg.py diff --git a/bob/learn/tensorflow/estimators/Logits.py b/bob/learn/tensorflow/estimators/Logits.py index 28b7832..06b0398 100755 --- a/bob/learn/tensorflow/estimators/Logits.py +++ b/bob/learn/tensorflow/estimators/Logits.py @@ -80,6 +80,12 @@ class Logits(estimator.Estimator): "scopes": dict({"/": "/"}), "trainable_variables": [] } + + apply_moving_averages: bool + Apply exponential moving average in the training variables and in the loss. + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + By default the decay for the variable averages is 0.9999 and for the loss is 0.9 + """ def __init__(self, @@ -269,6 +275,13 @@ class LogitsCenterLoss(estimator.Estimator): "scopes": dict({"/": "/"}), "trainable_variables": [] } + + apply_moving_averages: bool + Apply exponential moving average in the training variables and in the loss. + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + By default the decay for the variable averages is 0.9999 and for the loss is 0.9 + + """ diff --git a/bob/learn/tensorflow/loss/BaseLoss.py b/bob/learn/tensorflow/loss/BaseLoss.py index 881c924..6053b8b 100644 --- a/bob/learn/tensorflow/loss/BaseLoss.py +++ b/bob/learn/tensorflow/loss/BaseLoss.py @@ -30,7 +30,6 @@ def mean_cross_entropy_loss(logits, labels, add_regularization_losses=True): tf.summary.scalar('cross_entropy_loss', cross_loss) tf.add_to_collection(tf.GraphKeys.LOSSES, cross_loss) - if add_regularization_losses: regularization_losses = tf.get_collection( tf.GraphKeys.REGULARIZATION_LOSSES) diff --git a/bob/learn/tensorflow/network/InceptionResnetV2.py b/bob/learn/tensorflow/network/InceptionResnetV2.py index a611a6a..b52fcc4 100644 --- a/bob/learn/tensorflow/network/InceptionResnetV2.py +++ b/bob/learn/tensorflow/network/InceptionResnetV2.py @@ -321,13 +321,6 @@ def inception_resnet_v2(inputs, [slim.batch_norm, slim.dropout], is_training=(mode == tf.estimator.ModeKeys.TRAIN)): - #with slim.arg_scope( - # [slim.conv2d, slim.fully_connected], - # weights_initializer=tf.truncated_normal_initializer(stddev=0.1), - # weights_regularizer=slim.l2_regularizer(weight_decay), - # normalizer_fn=slim.batch_norm, - # normalizer_params=batch_norm_params): - with slim.arg_scope( [slim.conv2d, slim.max_pool2d, slim.avg_pool2d], stride=1, diff --git a/bob/learn/tensorflow/network/__init__.py b/bob/learn/tensorflow/network/__init__.py index 5ea2dc4..5aca74a 100644 --- a/bob/learn/tensorflow/network/__init__.py +++ b/bob/learn/tensorflow/network/__init__.py @@ -5,7 +5,6 @@ from .MLP import mlp from .Embedding import Embedding from .InceptionResnetV2 import inception_resnet_v2, inception_resnet_v2_batch_norm from .InceptionResnetV1 import inception_resnet_v1, inception_resnet_v1_batch_norm -from .vgg import vgg_16 from . import SimpleCNN diff --git a/bob/learn/tensorflow/network/vgg.py b/bob/learn/tensorflow/network/vgg.py deleted file mode 100644 index f3453de..0000000 --- a/bob/learn/tensorflow/network/vgg.py +++ /dev/null @@ -1,300 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Contains model definitions for versions of the Oxford VGG network. - -These model definitions were introduced in the following technical report: - - Very Deep Convolutional Networks For Large-Scale Image Recognition - Karen Simonyan and Andrew Zisserman - arXiv technical report, 2015 - PDF: http://arxiv.org/pdf/1409.1556.pdf - ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf - CC-BY-4.0 - -More information can be obtained from the VGG website: -www.robots.ox.ac.uk/~vgg/research/very_deep/ - -Usage: - with slim.arg_scope(vgg.vgg_arg_scope()): - outputs, end_points = vgg.vgg_a(inputs) - - with slim.arg_scope(vgg.vgg_arg_scope()): - outputs, end_points = vgg.vgg_16(inputs) - -@@vgg_a -@@vgg_16 -@@vgg_19 -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib import layers -from tensorflow.contrib.framework.python.ops import arg_scope -from tensorflow.contrib.layers.python.layers import layers as layers_lib -from tensorflow.contrib.layers.python.layers import regularizers -from tensorflow.contrib.layers.python.layers import utils -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import variable_scope -import tensorflow.contrib.slim as slim -import tensorflow as tf - - -def vgg_arg_scope(weight_decay=0.0005): - """Defines the VGG arg scope. - - Args: - weight_decay: The l2 regularization coefficient. - - Returns: - An arg_scope. - """ - with arg_scope( - [layers.conv2d, layers_lib.fully_connected], - activation_fn=nn_ops.relu, - weights_regularizer=regularizers.l2_regularizer(weight_decay), - biases_initializer=init_ops.zeros_initializer()): - with arg_scope([layers.conv2d], padding='SAME') as arg_sc: - return arg_sc - - -def vgg_a(inputs, - num_classes=1000, - is_training=True, - dropout_keep_prob=0.5, - spatial_squeeze=True, - scope='vgg_a'): - """Oxford Net VGG 11-Layers version A Example. - - Note: All the fully_connected layers have been transformed to conv2d layers. - To use in classification mode, resize input to 224x224. - - Args: - inputs: a tensor of size [batch_size, height, width, channels]. - num_classes: number of predicted classes. - is_training: whether or not the model is being trained. - dropout_keep_prob: the probability that activations are kept in the dropout - layers during training. - spatial_squeeze: whether or not should squeeze the spatial dimensions of the - outputs. Useful to remove unnecessary dimensions for classification. - scope: Optional scope for the variables. - - Returns: - the last op containing the log predictions and end_points dict. - """ - with variable_scope.variable_scope(scope, 'vgg_a', [inputs]) as sc: - end_points_collection = sc.original_name_scope + '_end_points' - # Collect outputs for conv2d, fully_connected and max_pool2d. - with arg_scope( - [layers.conv2d, layers_lib.max_pool2d], - outputs_collections=end_points_collection): - net = layers_lib.repeat( - inputs, 1, layers.conv2d, 64, [3, 3], scope='conv1') - net = layers_lib.max_pool2d(net, [2, 2], scope='pool1') - net = layers_lib.repeat(net, 1, layers.conv2d, 128, [3, 3], scope='conv2') - net = layers_lib.max_pool2d(net, [2, 2], scope='pool2') - net = layers_lib.repeat(net, 2, layers.conv2d, 256, [3, 3], scope='conv3') - net = layers_lib.max_pool2d(net, [2, 2], scope='pool3') - net = layers_lib.repeat(net, 2, layers.conv2d, 512, [3, 3], scope='conv4') - net = layers_lib.max_pool2d(net, [2, 2], scope='pool4') - net = layers_lib.repeat(net, 2, layers.conv2d, 512, [3, 3], scope='conv5') - net = layers_lib.max_pool2d(net, [2, 2], scope='pool5') - # Use conv2d instead of fully_connected layers. - net = layers.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') - net = layers_lib.dropout( - net, dropout_keep_prob, is_training=is_training, scope='dropout6') - net = layers.conv2d(net, 4096, [1, 1], scope='fc7') - net = layers_lib.dropout( - net, dropout_keep_prob, is_training=is_training, scope='dropout7') - net = layers.conv2d( - net, - num_classes, [1, 1], - activation_fn=None, - normalizer_fn=None, - scope='fc8') - # Convert end_points_collection into a end_point dict. - end_points = utils.convert_collection_to_dict(end_points_collection) - if spatial_squeeze: - net = array_ops.squeeze(net, [1, 2], name='fc8/squeezed') - end_points[sc.name + '/fc8'] = net - return net, end_points - - -vgg_a.default_image_size = 224 - - -def vgg_16(inputs, - mode=tf.estimator.ModeKeys.TRAIN, - reuse=None, - dropout_keep_prob=0.5, - spatial_squeeze=True, - scope='vgg_16', - trainable_variables=None, - **kwargs): - """Oxford Net VGG 16-Layers version D Example. - - Note: All the fully_connected layers have been transformed to conv2d layers. - To use in classification mode, resize input to 224x224. - - Args: - inputs: a tensor of size [batch_size, height, width, channels]. - num_classes: number of predicted classes. - is_training: whether or not the model is being trained. - dropout_keep_prob: the probability that activations are kept in the dropout - layers during training. - spatial_squeeze: whether or not should squeeze the spatial dimensions of the - outputs. Useful to remove unnecessary dimensions for classification. - scope: Optional scope for the variables. - - Returns: - the last op containing the log predictions and end_points dict. - """ - - is_training = mode == tf.estimator.ModeKeys.TRAIN - end_points = dict() - with variable_scope.variable_scope(scope, 'vgg_16', [inputs]) as sc: - end_points_collection = sc.original_name_scope + '_end_points' - # Collect outputs for conv2d, fully_connected and max_pool2d. - with arg_scope( - [layers.conv2d, layers_lib.fully_connected, layers_lib.max_pool2d], - outputs_collections=end_points_collection): - - with slim.arg_scope(vgg_arg_scope()): - - name = "conv1" - net = layers_lib.repeat( - inputs, 2, layers.conv2d, 64, [3, 3], scope=name, reuse=reuse) - end_points[name] = net - net = layers_lib.max_pool2d(net, [2, 2], scope='pool1') - - name = "conv2" - net = layers_lib.repeat(net, 2, layers.conv2d, 128, [3, 3], scope=name, reuse=reuse) - end_points[name] = net - net = layers_lib.max_pool2d(net, [2, 2], scope='pool2') - - name = "conv3" - net = layers_lib.repeat(net, 3, layers.conv2d, 256, [3, 3], scope=name, reuse=reuse) - end_points[name] = net - net = layers_lib.max_pool2d(net, [2, 2], scope='pool3') - - name = "conv4" - net = layers_lib.repeat(net, 3, layers.conv2d, 512, [3, 3], scope=name, reuse=reuse) - end_points[name] = net - net = layers_lib.max_pool2d(net, [2, 2], scope='pool4') - - name = "conv5" - net = layers_lib.repeat(net, 3, layers.conv2d, 512, [3, 3], scope=name, reuse=reuse) - end_points[name] = net - net = layers_lib.max_pool2d(net, [2, 2], scope='pool5') - - # Tiago: Make things flat - net = layers_lib.flatten(net) - - # Use conv2d instead of fully_connected layers. - # net = layers.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6', reuse=reuse) - # net = layers.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6', reuse=reuse) - name = "fc6" - net = layers.fully_connected(net, 4096, activation_fn=tf.nn.relu, scope=name, reuse=reuse) - end_points[name] = net - - net = layers_lib.dropout( - net, dropout_keep_prob, is_training=is_training, scope='dropout6') - - # net = layers.conv2d(net, 4096, [1, 1], scope='fc7', reuse=reuse) - name = "fc7" - net = layers.fully_connected(net, 4096, activation_fn=tf.nn.relu, scope=name, reuse=reuse) - end_points[name] = net - net = layers_lib.dropout( - net, dropout_keep_prob, is_training=is_training, scope='dropout7') - - return net, end_points - - -vgg_16.default_image_size = 224 - - -def vgg_19(inputs, - num_classes=1000, - is_training=True, - dropout_keep_prob=0.5, - spatial_squeeze=True, - scope='vgg_19'): - """Oxford Net VGG 19-Layers version E Example. - - Note: All the fully_connected layers have been transformed to conv2d layers. - To use in classification mode, resize input to 224x224. - - Args: - inputs: a tensor of size [batch_size, height, width, channels]. - num_classes: number of predicted classes. - is_training: whether or not the model is being trained. - dropout_keep_prob: the probability that activations are kept in the dropout - layers during training. - spatial_squeeze: whether or not should squeeze the spatial dimensions of the - outputs. Useful to remove unnecessary dimensions for classification. - scope: Optional scope for the variables. - - Returns: - the last op containing the log predictions and end_points dict. - """ - with variable_scope.variable_scope(scope, 'vgg_19', [inputs]) as sc: - end_points_collection = sc.name + '_end_points' - # Collect outputs for conv2d, fully_connected and max_pool2d. - with arg_scope( - [layers.conv2d, layers_lib.fully_connected, layers_lib.max_pool2d], - outputs_collections=end_points_collection): - - with slim.arg_scope(vgg_arg_scope()): - - net = layers_lib.repeat( - inputs, 2, layers.conv2d, 64, [3, 3], scope='conv1') - net = layers_lib.max_pool2d(net, [2, 2], scope='pool1') - net = layers_lib.repeat(net, 2, layers.conv2d, 128, [3, 3], scope='conv2') - net = layers_lib.max_pool2d(net, [2, 2], scope='pool2') - net = layers_lib.repeat(net, 4, layers.conv2d, 256, [3, 3], scope='conv3') - net = layers_lib.max_pool2d(net, [2, 2], scope='pool3') - net = layers_lib.repeat(net, 4, layers.conv2d, 512, [3, 3], scope='conv4') - net = layers_lib.max_pool2d(net, [2, 2], scope='pool4') - net = layers_lib.repeat(net, 4, layers.conv2d, 512, [3, 3], scope='conv5') - net = layers_lib.max_pool2d(net, [2, 2], scope='pool5') - # Use conv2d instead of fully_connected layers. - net = layers.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') - net = layers_lib.dropout( - net, dropout_keep_prob, is_training=is_training, scope='dropout6') - net = layers.conv2d(net, 4096, [1, 1], scope='fc7') - net = layers_lib.dropout( - net, dropout_keep_prob, is_training=is_training, scope='dropout7') - net = layers.conv2d( - net, - num_classes, [1, 1], - activation_fn=None, - normalizer_fn=None, - scope='fc8') - # Convert end_points_collection into a end_point dict. - end_points = utils.convert_collection_to_dict(end_points_collection) - if spatial_squeeze: - net = array_ops.squeeze(net, [1, 2], name='fc8/squeezed') - end_points[sc.name + '/fc8'] = net - return net, end_points - - -vgg_19.default_image_size = 224 - -# Alias -vgg_d = vgg_16 -vgg_e = vgg_19 -- GitLab