diff --git a/bob/learn/tensorflow/network/SimpleCNN.py b/bob/learn/tensorflow/network/SimpleCNN.py
index 0bd86450228ae735a28da3caff7f7ab4b3e3a757..2be23869113ace75058f700d5e98f6d3b074e624 100644
--- a/bob/learn/tensorflow/network/SimpleCNN.py
+++ b/bob/learn/tensorflow/network/SimpleCNN.py
@@ -1,9 +1,12 @@
+import collections
 import tensorflow as tf
+from .utils import is_trainable
+from ..estimators import get_trainable_variables
 
 
 def create_conv_layer(inputs, mode, data_format, endpoints, number, filters,
                       kernel_size, pool_size, pool_strides,
-                      add_batch_norm=False):
+                      add_batch_norm=False, trainable_variables=None):
     bn_axis = 1 if data_format.lower() == 'channels_first' else 3
     training = mode == tf.estimator.ModeKeys.TRAIN
 
@@ -13,19 +16,22 @@ def create_conv_layer(inputs, mode, data_format, endpoints, number, filters,
         activation = tf.nn.relu
 
     name = 'conv{}'.format(number)
+    trainable = is_trainable(name, trainable_variables)
     conv = tf.layers.conv2d(
         inputs=inputs,
         filters=filters,
         kernel_size=kernel_size,
         padding="same",
         activation=activation,
-        data_format=data_format)
+        data_format=data_format,
+        trainable=trainable)
     endpoints[name] = conv
 
     if add_batch_norm:
         name = 'bn{}'.format(number)
+        trainable = is_trainable(name, trainable_variables)
         bn = tf.layers.batch_normalization(
-            conv, axis=bn_axis, training=training)
+            conv, axis=bn_axis, training=training, trainable=trainable)
         endpoints[name] = bn
 
     name = 'activation{}'.format(number)
@@ -44,7 +50,8 @@ def create_conv_layer(inputs, mode, data_format, endpoints, number, filters,
 
 
 def base_architecture(input_layer, mode, kernerl_size, data_format,
-                      add_batch_norm=False, **kwargs):
+                      add_batch_norm=False, trainable_variables=None,
+                      **kwargs):
     training = mode == tf.estimator.ModeKeys.TRAIN
     # Keep track of all the endpoints
     endpoints = {}
@@ -56,7 +63,8 @@ def base_architecture(input_layer, mode, kernerl_size, data_format,
     pool1 = create_conv_layer(
         inputs=input_layer, mode=mode, data_format=data_format,
         endpoints=endpoints, number=1, filters=32, kernel_size=kernerl_size,
-        pool_size=(2, 2), pool_strides=2, add_batch_norm=add_batch_norm)
+        pool_size=(2, 2), pool_strides=2, add_batch_norm=add_batch_norm,
+        trainable_variables=trainable_variables)
 
     # Convolutional Layer #2
     # Computes 64 features using a kernerl_size filter.
@@ -64,11 +72,11 @@ def base_architecture(input_layer, mode, kernerl_size, data_format,
     pool2 = create_conv_layer(
         inputs=pool1, mode=mode, data_format=data_format,
         endpoints=endpoints, number=2, filters=64, kernel_size=kernerl_size,
-        pool_size=(2, 2), pool_strides=2, add_batch_norm=add_batch_norm)
+        pool_size=(2, 2), pool_strides=2, add_batch_norm=add_batch_norm,
+        trainable_variables=trainable_variables)
 
     # Flatten tensor into a batch of vectors
-    # TODO: use tf.layers.flatten in tensorflow 1.4 and above
-    pool2_flat = tf.contrib.layers.flatten(pool2)
+    pool2_flat = tf.layers.flatten(pool2)
     endpoints['pool2_flat'] = pool2_flat
 
     # Dense Layer
@@ -78,14 +86,18 @@ def base_architecture(input_layer, mode, kernerl_size, data_format,
     else:
         activation = tf.nn.relu
 
+    name = 'dense'
+    trainable = is_trainable(name, trainable_variables)
     dense = tf.layers.dense(
-        inputs=pool2_flat, units=1024, activation=activation)
-    endpoints['dense'] = dense
+        inputs=pool2_flat, units=1024, activation=activation,
+        trainable=trainable)
+    endpoints[name] = dense
 
     if add_batch_norm:
         name = 'bn{}'.format(3)
+        trainable = is_trainable(name, trainable_variables)
         bn = tf.layers.batch_normalization(
-            dense, axis=1, training=training)
+            dense, axis=1, training=training, trainable=trainable)
         endpoints[name] = bn
 
     name = 'activation{}'.format(3)
@@ -109,18 +121,23 @@ def architecture(input_layer,
                  data_format='channels_last',
                  reuse=False,
                  add_batch_norm=False,
+                 trainable_variables=None,
                  **kwargs):
     with tf.variable_scope('SimpleCNN', reuse=reuse):
 
         dropout, endpoints = base_architecture(
             input_layer, mode, kernerl_size, data_format,
-            add_batch_norm=add_batch_norm)
+            add_batch_norm=add_batch_norm,
+            trainable_variables=trainable_variables)
 
         # Logits layer
         # Input Tensor Shape: [batch_size, 1024]
         # Output Tensor Shape: [batch_size, n_classes]
-        logits = tf.layers.dense(inputs=dropout, units=n_classes)
-        endpoints['logits'] = logits
+        name = 'logits'
+        trainable = is_trainable(name, trainable_variables)
+        logits = tf.layers.dense(inputs=dropout, units=n_classes,
+                                 trainable=trainable)
+        endpoints[name] = logits
 
     return logits, endpoints
 
@@ -133,17 +150,28 @@ def model_fn(features, labels, mode, params=None, config=None):
     params = params or {}
     learning_rate = params.get('learning_rate', 1e-5)
     apply_moving_averages = params.get('apply_moving_averages', False)
+    extra_checkpoint = params.get('extra_checkpoint', None)
+    trainable_variables = get_trainable_variables(extra_checkpoint)
+    loss_weights = params.get('loss_weights', 1.0)
 
     arch_kwargs = {
         'kernerl_size': params.get('kernerl_size', None),
         'n_classes': params.get('n_classes', None),
         'data_format': params.get('data_format', None),
-        'add_batch_norm': params.get('add_batch_norm', None)
+        'add_batch_norm': params.get('add_batch_norm', None),
+        'trainable_variables': trainable_variables,
     }
     arch_kwargs = {k: v for k, v in arch_kwargs.items() if v is not None}
 
     logits, _ = architecture(data, mode, **arch_kwargs)
 
+    # restore the model from an extra_checkpoint
+    if extra_checkpoint is not None and mode == tf.estimator.ModeKeys.TRAIN:
+        tf.train.init_from_checkpoint(
+            ckpt_dir_or_file=extra_checkpoint["checkpoint_path"],
+            assignment_map=extra_checkpoint["scopes"],
+        )
+
     predictions = {
         # Generate predictions (for PREDICT and EVAL mode)
         "classes": tf.argmax(input=logits, axis=1),
@@ -178,9 +206,13 @@ def model_fn(features, labels, mode, params=None, config=None):
 
     with tf.control_dependencies([variable_averages_op] + update_ops):
 
+        # convert per-class weights to per-sample weights
+        if isinstance(loss_weights, collections.Iterable):
+            loss_weights = tf.gather(loss_weights, labels)
+
         # Calculate Loss (for both TRAIN and EVAL modes)
         loss = tf.losses.sparse_softmax_cross_entropy(
-            logits=logits, labels=labels)
+            logits=logits, labels=labels, weights=loss_weights)
 
         if apply_moving_averages and mode == tf.estimator.ModeKeys.TRAIN:
             # Compute the moving average of all individual losses and the total
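
The is_trainable helper imported from .utils is not part of this diff. Judging
only from its call sites above (a layer name plus the optional
trainable_variables list), a minimal sketch of its presumed contract would be:

    def is_trainable(name, trainable_variables):
        # Presumed behaviour: None keeps the pre-change default where every
        # layer is trainable; otherwise only the listed names stay trainable.
        if trainable_variables is None:
            return True
        return name in trainable_variables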
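
Usage sketch for the new params: extra_checkpoint["checkpoint_path"] and
extra_checkpoint["scopes"] are read by model_fn itself, while
get_trainable_variables presumably derives the trainable-variable list from
the same dict (the "trainable_variables" key below is an assumption, not
confirmed by this diff); all values are illustrative, not a definitive
configuration:

    import tensorflow as tf

    from bob.learn.tensorflow.network.SimpleCNN import model_fn

    params = {
        'n_classes': 2,
        # per-class weights; tf.gather(loss_weights, labels) expands them to
        # the per-sample weights tf.losses.sparse_softmax_cross_entropy expects
        'loss_weights': [1.0, 9.0],
        'extra_checkpoint': {
            'checkpoint_path': '/path/to/pretrained/model',  # hypothetical path
            # assignment_map handed to tf.train.init_from_checkpoint
            'scopes': {'SimpleCNN/': 'SimpleCNN/'},
            # assumed key read by get_trainable_variables: train only these
            'trainable_variables': ['dense', 'logits'],
        },
    }

    estimator = tf.estimator.Estimator(model_fn=model_fn, params=params)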