Commit d7615e54 authored by Amir MOHAMMADI
Browse files

Fix the batch normalization updates

parent d1853331
Pipeline #19541 failed in 47 minutes and 17 seconds
...@@ -62,7 +62,7 @@ def create_conv_layer(inputs, mode, data_format, endpoints, number, filters, ...@@ -62,7 +62,7 @@ def create_conv_layer(inputs, mode, data_format, endpoints, number, filters,
name = 'BN-{}'.format(number) name = 'BN-{}'.format(number)
bn = tf.layers.batch_normalization( bn = tf.layers.batch_normalization(
conv, axis=bn_axis, training=training, fused=True, name=name) conv, axis=bn_axis, training=training, name=name)
endpoints[name] = bn endpoints[name] = bn
name = 'Activation-{}'.format(number) name = 'Activation-{}'.format(number)
...@@ -75,7 +75,7 @@ def create_conv_layer(inputs, mode, data_format, endpoints, number, filters, ...@@ -75,7 +75,7 @@ def create_conv_layer(inputs, mode, data_format, endpoints, number, filters,
else: else:
pool = tf.layers.max_pooling2d( pool = tf.layers.max_pooling2d(
inputs=bn_act, pool_size=pool_size, strides=pool_strides, inputs=bn_act, pool_size=pool_size, strides=pool_strides,
data_format=data_format, name=name) padding='same', data_format=data_format, name=name)
endpoints[name] = pool endpoints[name] = pool
return pool return pool
...@@ -91,7 +91,7 @@ def create_dense_layer(inputs, mode, endpoints, number, units): ...@@ -91,7 +91,7 @@ def create_dense_layer(inputs, mode, endpoints, number, units):
name = 'BN-{}'.format(number + 5) name = 'BN-{}'.format(number + 5)
bn = tf.layers.batch_normalization( bn = tf.layers.batch_normalization(
fc, axis=1, training=training, fused=True, name=name) fc, axis=1, training=training, name=name)
endpoints[name] = bn endpoints[name] = bn
name = 'Activation-{}'.format(number + 5) name = 'Activation-{}'.format(number + 5)
...@@ -232,7 +232,7 @@ def model_fn(features, labels, mode, params=None, config=None): ...@@ -232,7 +232,7 @@ def model_fn(features, labels, mode, params=None, config=None):
# Calculate Loss (for both TRAIN and EVAL modes) # Calculate Loss (for both TRAIN and EVAL modes)
loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels) loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels)
# Add the regularization terms to the loss # Add the regularization terms to the loss
if tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES): if regularization_rate:
loss += regularization_rate * \ loss += regularization_rate * \
tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
...@@ -242,10 +242,11 @@ def model_fn(features, labels, mode, params=None, config=None): ...@@ -242,10 +242,11 @@ def model_fn(features, labels, mode, params=None, config=None):
# Configure the training op # Configure the training op
if mode == tf.estimator.ModeKeys.TRAIN: if mode == tf.estimator.ModeKeys.TRAIN:
global_step = tf.train.get_or_create_global_step()
learning_rate = tf.train.exponential_decay( learning_rate = tf.train.exponential_decay(
learning_rate=initial_learning_rate, learning_rate=initial_learning_rate,
global_step=tf.train.get_or_create_global_step(), global_step=global_step,
decay_steps=decay_steps, decay_steps=decay_steps,
decay_rate=decay_rate, decay_rate=decay_rate,
staircase=staircase) staircase=staircase)
...@@ -253,8 +254,12 @@ def model_fn(features, labels, mode, params=None, config=None): ...@@ -253,8 +254,12 @@ def model_fn(features, labels, mode, params=None, config=None):
optimizer = tf.train.MomentumOptimizer( optimizer = tf.train.MomentumOptimizer(
learning_rate=learning_rate, learning_rate=learning_rate,
momentum=momentum) momentum=momentum)
train_op = optimizer.minimize(
loss=loss, global_step=tf.train.get_or_create_global_step()) # for batch normalization to be updated as well:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(
loss=loss, global_step=global_step)
# Log accuracy and loss # Log accuracy and loss
with tf.name_scope('train_metrics'): with tf.name_scope('train_metrics'):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment