diff --git a/bob/learn/tensorflow/network/MLP.py b/bob/learn/tensorflow/network/MLP.py
index 1af525f12bba51beefa9a781e97e2b65d1969483..a1b1a79d443cc18c839fe60cadf1c63c1021efa5 100644
--- a/bob/learn/tensorflow/network/MLP.py
+++ b/bob/learn/tensorflow/network/MLP.py
@@ -3,6 +3,8 @@
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
 
 import tensorflow as tf
+from bob.learn.tensorflow.network.utils import is_trainable
+slim = tf.contrib.slim
 
 
 def mlp(inputs,
@@ -32,10 +34,9 @@ def mlp(inputs,
 
     output_activation: Activation of the output layer. If you set to `None`, the activation will be linear
 
-    seed: 
+    seed:
     """
 
-    slim = tf.contrib.slim
     initializer = tf.contrib.layers.xavier_initializer(
         uniform=False, dtype=tf.float32, seed=seed)
 
@@ -58,3 +59,53 @@ def mlp(inputs,
             scope='fc_output')
 
     return graph
+
+
+def mlp_with_batchnorm_and_dropout(inputs,
+                                   fully_connected_layers,
+                                   mode=tf.estimator.ModeKeys.TRAIN,
+                                   trainable_variables=None,
+                                   **kwargs):
+
+    if trainable_variables is not None:
+        raise ValueError("The batch_norm layers selectable training is not implemented!")
+
+    end_points = {}
+    net = slim.flatten(inputs)
+
+    weight_decay = 1e-5
+    dropout_keep_prob = 0.5
+    batch_norm_params = {
+        # Decay for the moving averages.
+        'decay': 0.995,
+        # epsilon to prevent 0s in variance.
+        'epsilon': 0.001,
+        # force in-place updates of mean and variance estimates
+        'updates_collections': None,
+        'is_training': (mode == tf.estimator.ModeKeys.TRAIN),
+    }
+
+    with slim.arg_scope(
+            [slim.fully_connected],
+            weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
+            weights_regularizer=slim.l2_regularizer(weight_decay),
+            normalizer_fn=slim.batch_norm,
+            normalizer_params=batch_norm_params
+    ), tf.name_scope('MLP'):
+
+        # hidden layers
+        for i, n in enumerate(fully_connected_layers):
+            name = 'fc_{:0d}'.format(i)
+            trainable = is_trainable(name, trainable_variables, mode=mode)
+            net = slim.fully_connected(net, n, scope=name, trainable=trainable)
+            end_points[name] = net
+
+            name = 'dropout_{:0d}'.format(i)
+            net = slim.dropout(
+                net,
+                dropout_keep_prob,
+                is_training=(mode == tf.estimator.ModeKeys.TRAIN),
+                scope=name)
+            end_points[name] = net
+
+    return net, end_points
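
For reference, a minimal usage sketch of the new function (not part of the patch). The input shape and layer sizes are illustrative assumptions, and it presumes a TF 1.x runtime where tf.contrib is available:

    import tensorflow as tf
    from bob.learn.tensorflow.network.MLP import mlp_with_batchnorm_and_dropout

    # Hypothetical 28x28 single-channel input batch; slim.flatten reshapes it to 2-D.
    images = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))

    # Two hidden layers of 64 and 32 units, each followed by batch norm and dropout.
    net, end_points = mlp_with_batchnorm_and_dropout(
        images,
        fully_connected_layers=[64, 32],
        mode=tf.estimator.ModeKeys.TRAIN)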