Commit b24540e2 authored by Amir MOHAMMADI

Add moving average and transfer learning to simplecnn

parent 54c0a2af
Merge request !50: Improvements on Simplecnn
import collections

import tensorflow as tf

from .utils import is_trainable
from ..estimators import get_trainable_variables


def create_conv_layer(inputs, mode, data_format, endpoints, number, filters,
                      kernel_size, pool_size, pool_strides,
                      add_batch_norm=False, trainable_variables=None):
    bn_axis = 1 if data_format.lower() == 'channels_first' else 3
    training = mode == tf.estimator.ModeKeys.TRAIN
@@ -13,19 +16,22 @@ def create_conv_layer(inputs, mode, data_format, endpoints, number, filters,
        activation = tf.nn.relu

    name = 'conv{}'.format(number)
    trainable = is_trainable(name, trainable_variables)
    conv = tf.layers.conv2d(
        inputs=inputs,
        filters=filters,
        kernel_size=kernel_size,
        padding="same",
        activation=activation,
        data_format=data_format,
        trainable=trainable)
    endpoints[name] = conv

    if add_batch_norm:
        name = 'bn{}'.format(number)
        trainable = is_trainable(name, trainable_variables)
        bn = tf.layers.batch_normalization(
            conv, axis=bn_axis, training=training, trainable=trainable)
        endpoints[name] = bn

        name = 'activation{}'.format(number)
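For readers outside the diff: `is_trainable`, imported from `.utils`, is what gates each layer on `trainable_variables`. Its body is not part of this commit, so the following is a minimal sketch of its assumed behavior:

def is_trainable(name, trainable_variables):
    # When no list is supplied, keep the default: every layer trains.
    if trainable_variables is None:
        return True
    # Otherwise only the explicitly listed layers keep training.
    return name in trainable_variables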
@@ -44,7 +50,8 @@ def create_conv_layer(inputs, mode, data_format, endpoints, number, filters,
def base_architecture(input_layer, mode, kernerl_size, data_format,
                      add_batch_norm=False, trainable_variables=None,
                      **kwargs):
    training = mode == tf.estimator.ModeKeys.TRAIN
    # Keep track of all the endpoints
    endpoints = {}
@@ -56,7 +63,8 @@ def base_architecture(input_layer, mode, kernerl_size, data_format,
    pool1 = create_conv_layer(
        inputs=input_layer, mode=mode, data_format=data_format,
        endpoints=endpoints, number=1, filters=32, kernel_size=kernerl_size,
        pool_size=(2, 2), pool_strides=2, add_batch_norm=add_batch_norm,
        trainable_variables=trainable_variables)

    # Convolutional Layer #2
    # Computes 64 features using a kernerl_size filter.
@@ -64,11 +72,11 @@ def base_architecture(input_layer, mode, kernerl_size, data_format,
    pool2 = create_conv_layer(
        inputs=pool1, mode=mode, data_format=data_format,
        endpoints=endpoints, number=2, filters=64, kernel_size=kernerl_size,
        pool_size=(2, 2), pool_strides=2, add_batch_norm=add_batch_norm,
        trainable_variables=trainable_variables)

    # Flatten tensor into a batch of vectors
    pool2_flat = tf.layers.flatten(pool2)
    endpoints['pool2_flat'] = pool2_flat

    # Dense Layer
@@ -78,14 +86,18 @@ def base_architecture(input_layer, mode, kernerl_size, data_format,
    else:
        activation = tf.nn.relu

    name = 'dense'
    trainable = is_trainable(name, trainable_variables)
    dense = tf.layers.dense(
        inputs=pool2_flat, units=1024, activation=activation,
        trainable=trainable)
    endpoints[name] = dense

    if add_batch_norm:
        name = 'bn{}'.format(3)
        trainable = is_trainable(name, trainable_variables)
        bn = tf.layers.batch_normalization(
            dense, axis=1, training=training, trainable=trainable)
        endpoints[name] = bn

        name = 'activation{}'.format(3)
@@ -109,18 +121,23 @@ def architecture(input_layer,
                 data_format='channels_last',
                 reuse=False,
                 add_batch_norm=False,
                 trainable_variables=None,
                 **kwargs):
    with tf.variable_scope('SimpleCNN', reuse=reuse):
        dropout, endpoints = base_architecture(
            input_layer, mode, kernerl_size, data_format,
            add_batch_norm=add_batch_norm,
            trainable_variables=trainable_variables)

        # Logits layer
        # Input Tensor Shape: [batch_size, 1024]
        # Output Tensor Shape: [batch_size, n_classes]
        name = 'logits'
        trainable = is_trainable(name, trainable_variables)
        logits = tf.layers.dense(inputs=dropout, units=n_classes,
                                 trainable=trainable)
        endpoints[name] = logits

    return logits, endpoints
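For illustration, a hedged sketch of calling `architecture` directly for transfer learning; the input shape and the leading `kernerl_size`/`n_classes` parameters (their declarations are elided by the hunk above) are assumptions:

images = tf.placeholder(tf.float32, [None, 28, 28, 1])  # hypothetical input batch
logits, endpoints = architecture(
    images, tf.estimator.ModeKeys.TRAIN,
    kernerl_size=(3, 3), n_classes=10,
    trainable_variables=['logits'])  # freeze everything except the logits layer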
@@ -133,17 +150,28 @@ def model_fn(features, labels, mode, params=None, config=None):
    params = params or {}
    learning_rate = params.get('learning_rate', 1e-5)
    apply_moving_averages = params.get('apply_moving_averages', False)
    extra_checkpoint = params.get('extra_checkpoint', None)
    trainable_variables = get_trainable_variables(extra_checkpoint)
    loss_weights = params.get('loss_weights', 1.0)

    arch_kwargs = {
        'kernerl_size': params.get('kernerl_size', None),
        'n_classes': params.get('n_classes', None),
        'data_format': params.get('data_format', None),
        'add_batch_norm': params.get('add_batch_norm', None),
        'trainable_variables': trainable_variables,
    }
    arch_kwargs = {k: v for k, v in arch_kwargs.items() if v is not None}

    logits, _ = architecture(data, mode, **arch_kwargs)

    # restore the model from an extra_checkpoint
    if extra_checkpoint is not None and mode == tf.estimator.ModeKeys.TRAIN:
        tf.train.init_from_checkpoint(
            ckpt_dir_or_file=extra_checkpoint["checkpoint_path"],
            assignment_map=extra_checkpoint["scopes"],
        )
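    # Illustration only, not part of this commit: with a checkpoint trained
    # under the same 'SimpleCNN' scope, an identity assignment map works, e.g.
    #     extra_checkpoint = {
    #         "checkpoint_path": "/path/to/pretrained/model",
    #         "scopes": {"SimpleCNN/": "SimpleCNN/"},
    #         "trainable_variables": ["logits"],
    #     }
    # (the "trainable_variables" key is an assumption about what
    # get_trainable_variables reads). tf.train.init_from_checkpoint replaces
    # the matched variables' initializers, so the values are loaded when the
    # graph is initialized.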
    predictions = {
        # Generate predictions (for PREDICT and EVAL mode)
        "classes": tf.argmax(input=logits, axis=1),
@@ -178,9 +206,13 @@ def model_fn(features, labels, mode, params=None, config=None):
    with tf.control_dependencies([variable_averages_op] + update_ops):
        # convert per-class weights to per-sample weights
        if isinstance(loss_weights, collections.Iterable):
            loss_weights = tf.gather(loss_weights, labels)

        # Calculate Loss (for both TRAIN and EVAL modes)
        loss = tf.losses.sparse_softmax_cross_entropy(
            logits=logits, labels=labels, weights=loss_weights)

        if apply_moving_averages and mode == tf.estimator.ModeKeys.TRAIN:
            # Compute the moving average of all individual losses and the total
...
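The weighting logic above accepts either a scalar or a per-class vector for `loss_weights`; when a vector is given, `tf.gather` expands it to one weight per sample before it reaches the loss. A small worked example with illustrative values:

# loss_weights = [1.0, 5.0]            # per-class weights for 2 classes
# labels       = [0, 1, 1, 0]
# tf.gather(loss_weights, labels)  ->  [1.0, 5.0, 5.0, 1.0]
# so each class-1 sample counts five times as much in the cross-entropy.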
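Putting the new options together, a hedged sketch of driving this model_fn through a tf.estimator.Estimator; every path and value below is a placeholder:

import tensorflow as tf

params = {
    'n_classes': 2,
    'kernerl_size': (3, 3),
    'learning_rate': 1e-4,
    'apply_moving_averages': True,  # enable the moving-average ops added here
    'loss_weights': [1.0, 5.0],     # optional per-class weighting
    'extra_checkpoint': {           # optional transfer learning
        'checkpoint_path': '/path/to/pretrained/model',
        'scopes': {'SimpleCNN/': 'SimpleCNN/'},
        'trainable_variables': ['logits'],
    },
}
estimator = tf.estimator.Estimator(
    model_fn=model_fn, model_dir='/tmp/simplecnn', params=params)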