Commit b24540e2 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Add moving average and transfer learning to simplecnn

parent 54c0a2af
Pipeline #20377 passed with stage
in 44 minutes and 29 seconds
import collections
import tensorflow as tf
from .utils import is_trainable
from ..estimators import get_trainable_variables
def create_conv_layer(inputs, mode, data_format, endpoints, number, filters,
kernel_size, pool_size, pool_strides,
add_batch_norm=False):
add_batch_norm=False, trainable_variables=None):
bn_axis = 1 if data_format.lower() == 'channels_first' else 3
training = mode == tf.estimator.ModeKeys.TRAIN
......@@ -13,19 +16,22 @@ def create_conv_layer(inputs, mode, data_format, endpoints, number, filters,
activation = tf.nn.relu
name = 'conv{}'.format(number)
trainable = is_trainable(name, trainable_variables)
conv = tf.layers.conv2d(
inputs=inputs,
filters=filters,
kernel_size=kernel_size,
padding="same",
activation=activation,
data_format=data_format)
data_format=data_format,
trainable=trainable)
endpoints[name] = conv
if add_batch_norm:
name = 'bn{}'.format(number)
trainable = is_trainable(name, trainable_variables)
bn = tf.layers.batch_normalization(
conv, axis=bn_axis, training=training)
conv, axis=bn_axis, training=training, trainable=trainable)
endpoints[name] = bn
name = 'activation{}'.format(number)
......@@ -44,7 +50,8 @@ def create_conv_layer(inputs, mode, data_format, endpoints, number, filters,
def base_architecture(input_layer, mode, kernerl_size, data_format,
add_batch_norm=False, **kwargs):
add_batch_norm=False, trainable_variables=None,
**kwargs):
training = mode == tf.estimator.ModeKeys.TRAIN
# Keep track of all the endpoints
endpoints = {}
......@@ -56,7 +63,8 @@ def base_architecture(input_layer, mode, kernerl_size, data_format,
pool1 = create_conv_layer(
inputs=input_layer, mode=mode, data_format=data_format,
endpoints=endpoints, number=1, filters=32, kernel_size=kernerl_size,
pool_size=(2, 2), pool_strides=2, add_batch_norm=add_batch_norm)
pool_size=(2, 2), pool_strides=2, add_batch_norm=add_batch_norm,
trainable_variables=trainable_variables)
# Convolutional Layer #2
# Computes 64 features using a kernerl_size filter.
......@@ -64,11 +72,11 @@ def base_architecture(input_layer, mode, kernerl_size, data_format,
pool2 = create_conv_layer(
inputs=pool1, mode=mode, data_format=data_format,
endpoints=endpoints, number=2, filters=64, kernel_size=kernerl_size,
pool_size=(2, 2), pool_strides=2, add_batch_norm=add_batch_norm)
pool_size=(2, 2), pool_strides=2, add_batch_norm=add_batch_norm,
trainable_variables=trainable_variables)
# Flatten tensor into a batch of vectors
# TODO: use tf.layers.flatten in tensorflow 1.4 and above
pool2_flat = tf.contrib.layers.flatten(pool2)
pool2_flat = tf.layers.flatten(pool2)
endpoints['pool2_flat'] = pool2_flat
# Dense Layer
......@@ -78,14 +86,18 @@ def base_architecture(input_layer, mode, kernerl_size, data_format,
else:
activation = tf.nn.relu
name = 'dense'
trainable = is_trainable(name, trainable_variables)
dense = tf.layers.dense(
inputs=pool2_flat, units=1024, activation=activation)
endpoints['dense'] = dense
inputs=pool2_flat, units=1024, activation=activation,
trainable=trainable)
endpoints[name] = dense
if add_batch_norm:
name = 'bn{}'.format(3)
trainable = is_trainable(name, trainable_variables)
bn = tf.layers.batch_normalization(
dense, axis=1, training=training)
dense, axis=1, training=training, trainable=trainable)
endpoints[name] = bn
name = 'activation{}'.format(3)
......@@ -109,18 +121,23 @@ def architecture(input_layer,
data_format='channels_last',
reuse=False,
add_batch_norm=False,
trainable_variables=None,
**kwargs):
with tf.variable_scope('SimpleCNN', reuse=reuse):
dropout, endpoints = base_architecture(
input_layer, mode, kernerl_size, data_format,
add_batch_norm=add_batch_norm)
add_batch_norm=add_batch_norm,
trainable_variables=trainable_variables)
# Logits layer
# Input Tensor Shape: [batch_size, 1024]
# Output Tensor Shape: [batch_size, n_classes]
logits = tf.layers.dense(inputs=dropout, units=n_classes)
endpoints['logits'] = logits
name = 'logits'
trainable = is_trainable(name, trainable_variables)
logits = tf.layers.dense(inputs=dropout, units=n_classes,
trainable=trainable)
endpoints[name] = logits
return logits, endpoints
......@@ -133,17 +150,28 @@ def model_fn(features, labels, mode, params=None, config=None):
params = params or {}
learning_rate = params.get('learning_rate', 1e-5)
apply_moving_averages = params.get('apply_moving_averages', False)
extra_checkpoint = params.get('extra_checkpoint', None)
trainable_variables = get_trainable_variables(extra_checkpoint)
loss_weights = params.get('loss_weights', 1.0)
arch_kwargs = {
'kernerl_size': params.get('kernerl_size', None),
'n_classes': params.get('n_classes', None),
'data_format': params.get('data_format', None),
'add_batch_norm': params.get('add_batch_norm', None)
'add_batch_norm': params.get('add_batch_norm', None),
'trainable_variables': trainable_variables,
}
arch_kwargs = {k: v for k, v in arch_kwargs.items() if v is not None}
logits, _ = architecture(data, mode, **arch_kwargs)
# restore the model from an extra_checkpoint
if extra_checkpoint is not None and mode == tf.estimator.ModeKeys.TRAIN:
tf.train.init_from_checkpoint(
ckpt_dir_or_file=extra_checkpoint["checkpoint_path"],
assignment_map=extra_checkpoint["scopes"],
)
predictions = {
# Generate predictions (for PREDICT and EVAL mode)
"classes": tf.argmax(input=logits, axis=1),
......@@ -178,9 +206,13 @@ def model_fn(features, labels, mode, params=None, config=None):
with tf.control_dependencies([variable_averages_op] + update_ops):
# convert weights of per sample to weights per class
if isinstance(loss_weights, collections.Iterable):
loss_weights = tf.gather(loss_weights, labels)
# Calculate Loss (for both TRAIN and EVAL modes)
loss = tf.losses.sparse_softmax_cross_entropy(
logits=logits, labels=labels)
logits=logits, labels=labels, weights=loss_weights)
if apply_moving_averages and mode == tf.estimator.ModeKeys.TRAIN:
# Compute the moving average of all individual losses and the total
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment