Commit d7615e54 authored by Amir MOHAMMADI

Fix the batch normalization updates

parent d1853331
Pipeline #19541 failed in 47 minutes and 17 seconds
@@ -62,7 +62,7 @@ def create_conv_layer(inputs, mode, data_format, endpoints, number, filters,
     name = 'BN-{}'.format(number)
     bn = tf.layers.batch_normalization(
-        conv, axis=bn_axis, training=training, fused=True, name=name)
+        conv, axis=bn_axis, training=training, name=name)
     endpoints[name] = bn
     name = 'Activation-{}'.format(number)
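Context for the change above: in TensorFlow 1.x, `tf.layers.batch_normalization(..., training=training)` does not update its moving mean/variance by itself; it only registers the assign ops in the `tf.GraphKeys.UPDATE_OPS` collection, and leaving `fused` at its default (`None`) lets TensorFlow pick the fused kernel whenever the input layout supports it. A minimal sketch (TF 1.x assumed; shapes are hypothetical) showing the collection being populated:

```python
# Minimal sketch, TensorFlow 1.x assumed; shapes are hypothetical.
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 28, 28, 3])
conv = tf.layers.conv2d(x, filters=16, kernel_size=3)
bn = tf.layers.batch_normalization(conv, axis=-1, training=True)

# The moving-mean/variance assign ops are only *collected* here; nothing
# runs them unless the train op depends on them (see the last hunk below).
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
```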
@@ -75,7 +75,7 @@ def create_conv_layer(inputs, mode, data_format, endpoints, number, filters,
     else:
         pool = tf.layers.max_pooling2d(
             inputs=bn_act, pool_size=pool_size, strides=pool_strides,
-            data_format=data_format, name=name)
+            padding='same', data_format=data_format, name=name)
     endpoints[name] = pool
     return pool
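The `padding='same'` addition changes the pooling arithmetic: `tf.layers.max_pooling2d` defaults to `padding='valid'`, which discards trailing rows and columns whenever the input size is not a multiple of the stride, while `'same'` pads so the output size is `ceil(input_size / stride)`. A quick sketch with a hypothetical odd-sized input where the two differ:

```python
# Sketch of 'valid' vs. 'same' pooling output sizes (hypothetical 29x29 input).
import tensorflow as tf

x = tf.zeros([1, 29, 29, 8])
valid = tf.layers.max_pooling2d(x, pool_size=2, strides=2, padding='valid')
same = tf.layers.max_pooling2d(x, pool_size=2, strides=2, padding='same')
print(valid.shape)  # (1, 14, 14, 8): floor((29 - 2) / 2) + 1 = 14
print(same.shape)   # (1, 15, 15, 8): ceil(29 / 2) = 15
```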
@@ -91,7 +91,7 @@ def create_dense_layer(inputs, mode, endpoints, number, units):
     name = 'BN-{}'.format(number + 5)
     bn = tf.layers.batch_normalization(
-        fc, axis=1, training=training, fused=True, name=name)
+        fc, axis=1, training=training, name=name)
     endpoints[name] = bn
     name = 'Activation-{}'.format(number + 5)
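The dense layers get the same `fused=True` removal; here `axis=1` selects the feature dimension of the `[batch, units]` tensor, so each unit keeps its own moving statistics. A tiny sketch (hypothetical sizes):

```python
# Sketch: batch norm over dense features is per-unit, i.e. axis=1 on a
# [batch, units] tensor (TF 1.x assumed; sizes hypothetical).
import tensorflow as tf

fc = tf.layers.dense(tf.zeros([4, 128]), units=64)
bn = tf.layers.batch_normalization(fc, axis=1, training=True)
print(bn.shape)  # (4, 64): one mean/variance pair per unit
```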
@@ -232,7 +232,7 @@ def model_fn(features, labels, mode, params=None, config=None):
     # Calculate Loss (for both TRAIN and EVAL modes)
     loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels)
     # Add the regularization terms to the loss
-    if tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES):
+    if regularization_rate:
         loss += regularization_rate * \
             tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
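With the new guard, the graph skips the regularization term entirely when `regularization_rate` is falsy (e.g. 0), instead of keying on whether the collection is non-empty. Note that `tf.add_n` raises a `ValueError` on an empty list, so this relies on regularizers actually being registered whenever a non-zero rate is configured. A defensive sketch checking both conditions (my combination, not from the commit; the rate value is hypothetical):

```python
# Defensive sketch combining both guards (TF 1.x assumed).
import tensorflow as tf

regularization_rate = 1e-4  # hypothetical hyperparameter
loss = tf.constant(0.0)     # stand-in for the cross-entropy loss

reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
if regularization_rate and reg_losses:
    # tf.add_n raises ValueError on an empty list, hence the extra check.
    loss += regularization_rate * tf.add_n(reg_losses)
```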
@@ -242,10 +242,11 @@ def model_fn(features, labels, mode, params=None, config=None):
     # Configure the training op
     if mode == tf.estimator.ModeKeys.TRAIN:
+        global_step = tf.train.get_or_create_global_step()
         learning_rate = tf.train.exponential_decay(
             learning_rate=initial_learning_rate,
-            global_step=tf.train.get_or_create_global_step(),
+            global_step=global_step,
             decay_steps=decay_steps,
             decay_rate=decay_rate,
             staircase=staircase)
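Hoisting `global_step` into a local name is functionally equivalent (`tf.train.get_or_create_global_step` returns the same variable on repeated calls within a graph) but makes it explicit that the decay schedule and the optimizer's step counter use one tensor. For reference, `tf.train.exponential_decay` computes `initial_learning_rate * decay_rate ** (global_step / decay_steps)`, with the exponent floored when `staircase=True`. A sketch with hypothetical hyperparameters:

```python
# Sketch of the decay schedule (TF 1.x assumed; hyperparameters hypothetical).
import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
learning_rate = tf.train.exponential_decay(
    learning_rate=0.01,    # initial learning rate
    global_step=global_step,
    decay_steps=1000,
    decay_rate=0.96,
    staircase=True)
# With staircase=True the exponent is floored, so the rate follows
# 0.01 * 0.96 ** (global_step // 1000): constant within each 1000 steps.
```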
@@ -253,8 +254,12 @@ def model_fn(features, labels, mode, params=None, config=None):
         optimizer = tf.train.MomentumOptimizer(
             learning_rate=learning_rate,
             momentum=momentum)
-        train_op = optimizer.minimize(
-            loss=loss, global_step=tf.train.get_or_create_global_step())
+        # for batch normalization to be updated as well:
+        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+        with tf.control_dependencies(update_ops):
+            train_op = optimizer.minimize(
+                loss=loss, global_step=global_step)
         # Log accuracy and loss
         with tf.name_scope('train_metrics'):
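This last hunk is the actual fix named in the commit title: the batch-normalization assign ops collected in `tf.GraphKeys.UPDATE_OPS` run only if something depends on them, so wrapping `optimizer.minimize` in `tf.control_dependencies(update_ops)` makes every training step also refresh the moving mean and variance. Without it, evaluation (`training=False`) normalizes with the never-updated initial statistics. A self-contained sketch of the pattern (TF 1.x assumed; the toy model and hyperparameters are mine):

```python
# Self-contained sketch of the UPDATE_OPS pattern (TF 1.x; toy model hypothetical).
import tensorflow as tf

x = tf.random_normal([32, 10])
bn = tf.layers.batch_normalization(x, axis=1, training=True)
loss = tf.reduce_mean(tf.square(bn))

global_step = tf.train.get_or_create_global_step()
optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9)

# Running train_op now also runs the moving-mean/variance assign ops;
# without the control dependency they would never execute.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss, global_step=global_step)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(10):
        sess.run(train_op)
```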