Commit 90b4835f authored by Amir MOHAMMADI

Add batch norm to simplecnn

parent 7ac847c8
1 merge request: !47 (Many changes)
import tensorflow as tf
def create_conv_layer(inputs, mode, data_format, endpoints, number, filters,
                      kernel_size, pool_size, pool_strides,
                      add_batch_norm=False):
    # Batch norm acts on the channel axis, whose position depends on the
    # data format.
    bn_axis = 1 if data_format.lower() == 'channels_first' else 3
    training = mode == tf.estimator.ModeKeys.TRAIN

    # With batch norm, the ReLU is deferred until after normalization
    # (conv -> bn -> relu); otherwise the convolution applies it directly.
    if add_batch_norm:
        activation = None
    else:
        activation = tf.nn.relu

    name = 'conv{}'.format(number)
    conv = tf.layers.conv2d(
        inputs=inputs,
        filters=filters,
        kernel_size=kernel_size,
        padding="same",
        activation=activation,
        data_format=data_format)
    endpoints[name] = conv

    if add_batch_norm:
        name = 'bn{}'.format(number)
        bn = tf.layers.batch_normalization(
            conv, axis=bn_axis, training=training)
        endpoints[name] = bn

        name = 'activation{}'.format(number)
        bn_act = tf.nn.relu(bn)
        endpoints[name] = bn_act
    else:
        bn_act = conv

    name = 'pool{}'.format(number)
    pool = tf.layers.max_pooling2d(
        inputs=bn_act, pool_size=pool_size, strides=pool_strides,
        padding='same', data_format=data_format)
    endpoints[name] = pool

    return pool
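
A minimal sketch of exercising the helper on its own; the placeholder and the
sizes below are illustrative, not part of this commit:

endpoints = {}
images = tf.placeholder(tf.float32, [None, 28, 28, 1])  # hypothetical input
pool1 = create_conv_layer(
    inputs=images, mode=tf.estimator.ModeKeys.TRAIN,
    data_format='channels_last', endpoints=endpoints, number=1,
    filters=32, kernel_size=(3, 3), pool_size=(2, 2), pool_strides=2,
    add_batch_norm=True)
# endpoints now holds 'conv1', 'bn1', 'activation1' and 'pool1'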
def base_architecture(input_layer, mode, kernerl_size, data_format,
                      add_batch_norm=False, **kwargs):
    training = mode == tf.estimator.ModeKeys.TRAIN

    # Keep track of all the endpoints
    endpoints = {}
@@ -9,38 +53,18 @@ def base_architecture(input_layer, mode, kernerl_size, data_format, **kwargs):
    # Computes 32 features using a kernerl_size filter with ReLU
    # activation.
    # Padding is added to preserve width and height.
    pool1 = create_conv_layer(
        inputs=input_layer, mode=mode, data_format=data_format,
        endpoints=endpoints, number=1, filters=32, kernel_size=kernerl_size,
        pool_size=(2, 2), pool_strides=2, add_batch_norm=add_batch_norm)
    # Convolutional Layer #2
    # Computes 64 features using a kernerl_size filter.
    # Padding is added to preserve width and height.
    pool2 = create_conv_layer(
        inputs=pool1, mode=mode, data_format=data_format,
        endpoints=endpoints, number=2, filters=64, kernel_size=kernerl_size,
        pool_size=(2, 2), pool_strides=2, add_batch_norm=add_batch_norm)
    # Flatten tensor into a batch of vectors
    # TODO: use tf.layers.flatten in tensorflow 1.4 and above
@@ -49,13 +73,30 @@ def base_architecture(input_layer, mode, kernerl_size, data_format, **kwargs):
    # Dense Layer
    # Densely connected layer with 1024 neurons
    if add_batch_norm:
        activation = None
    else:
        activation = tf.nn.relu

    dense = tf.layers.dense(
        inputs=pool2_flat, units=1024, activation=activation)
    endpoints['dense'] = dense

    if add_batch_norm:
        name = 'bn{}'.format(3)
        bn = tf.layers.batch_normalization(
            dense, axis=1, training=training)
        endpoints[name] = bn

        name = 'activation{}'.format(3)
        bn_act = tf.nn.relu(bn)
        endpoints[name] = bn_act
    else:
        bn_act = dense

    # Add dropout operation; 0.6 probability that element will be kept
    dropout = tf.layers.dropout(
        inputs=bn_act, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)
    endpoints['dropout'] = dropout

    return dropout, endpoints
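
As a rough sanity check of the tensor sizes, a hedged shape walk-through; it
assumes a 28x28 single-channel input, which this file itself does not fix:

# Assuming input_layer is [batch, 28, 28, 1] (channels_last):
#   conv1 (same padding) -> [batch, 28, 28, 32]; pool1 -> [batch, 14, 14, 32]
#   conv2 (same padding) -> [batch, 14, 14, 64]; pool2 -> [batch, 7, 7, 64]
#   pool2_flat -> [batch, 7 * 7 * 64] = [batch, 3136] into the 1024-unit dense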
@@ -67,12 +108,14 @@ def architecture(input_layer,
                 n_classes=2,
                 data_format='channels_last',
                 reuse=False,
                 add_batch_norm=False,
                 **kwargs):
    with tf.variable_scope('SimpleCNN', reuse=reuse):
        dropout, endpoints = base_architecture(
            input_layer, mode, kernerl_size, data_format,
            add_batch_norm=add_batch_norm)

        # Logits layer
        # Input Tensor Shape: [batch_size, 1024]
        # Output Tensor Shape: [batch_size, n_classes]
@@ -94,6 +137,7 @@ def model_fn(features, labels, mode, params=None, config=None):
        'kernerl_size': params.get('kernerl_size', None),
        'n_classes': params.get('n_classes', None),
        'data_format': params.get('data_format', None),
        'add_batch_norm': params.get('add_batch_norm', None),
    }

    arch_kwargs = {k: v for k, v in arch_kwargs.items() if v is not None}
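
Since entries left at None are dropped, any parameter the caller omits falls
back to the corresponding default in architecture (for instance,
add_batch_norm defaults to False). A tiny illustration with made-up values:

example = {'n_classes': 10, 'data_format': None, 'add_batch_norm': None}
example = {k: v for k, v in example.items() if v is not None}
assert example == {'n_classes': 10}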
@@ -118,10 +162,14 @@ def model_fn(features, labels, mode, params=None, config=None):
    # Configure the training op
    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()
        optimizer = tf.train.GradientSescentOptimizer(
            learning_rate=learning_rate) if False else tf.train.GradientDescentOptimizer(
            learning_rate=learning_rate)
        # For batch normalization to be updated as well, the train op must
        # run after the ops in the UPDATE_OPS collection (the moving mean
        # and variance updates).
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(
                loss=loss, global_step=global_step)

        # Log accuracy and loss
        with tf.name_scope('train_metrics'):
            tf.summary.scalar('accuracy', accuracy[1])
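
For context, a minimal sketch (not part of this commit) of how the new flag
would reach model_fn through tf.estimator params; the kernel size and class
count are made-up values, and note the code base spells the key 'kernerl_size':

estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    params={
        'kernerl_size': (3, 3),        # hypothetical kernel size
        'n_classes': 2,
        'data_format': 'channels_last',
        'add_batch_norm': True,        # the option added by this commit
    })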
...