
Gan

Closed · Guillaume HEUSCH requested to merge gan into master

1 file changed: +13 −6
@@ -66,7 +66,7 @@ class Layer(object):
     def variable_exist(self, var):
         return var in [v.name.split("/")[0] for v in tf.global_variables()]
 
-    def batch_normalize(self, x, phase_train):
+    def batch_normalize(self, x, phase_train, scope=None):
         """
         Batch normalization on convolutional maps.
         Ref: http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow/33950177
@@ -76,25 +76,32 @@ class Layer(object):
           phase_train:
         """
+        return tf.contrib.layers.batch_norm(x, decay=1.9, updates_collections=None, epsilon=1e-5, scale=True, is_training=phase_train, scope=scope)
         from tensorflow.python.ops import control_flow_ops
+        print scope
+        for v in tf.global_variables():
+            print v.name
-        name = "batch_norm_" + str(self.name)
-        reuse = self.variable_exist(name)
+        name = self.name + "/batch_norm"
+        reuse = self.variable_exist(self.name)
         #if reuse:
             #import ipdb; ipdb.set_trace();
-        with tf.variable_scope(name, reuse=reuse):
+        with tf.variable_scope(scope, reuse=reuse):
             phase_train = tf.convert_to_tensor(phase_train, dtype=tf.bool)
             n_out = int(x.get_shape()[-1])
-            self.beta = tf.get_variable(name + '_beta',
+            self.beta = tf.get_variable('beta',
                 initializer=tf.constant(0.0, shape=[n_out], dtype=x.dtype),
                 trainable=True,
                 dtype=x.dtype)
-            self.gamma = tf.get_variable(name + '_gamma',
+            self.gamma = tf.get_variable('gamma',
                 initializer=tf.constant(1.0, shape=[n_out], dtype=x.dtype),
                 trainable=True,
                 dtype=x.dtype)
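
For context, a minimal standalone sketch of the tf.contrib.layers.batch_norm call that the new early return delegates to, under TensorFlow 1.x. It is illustrative rather than part of this merge request: the input shape, the "conv1_bn" scope name, and decay=0.9 (a typical moving-average momentum just below 1; the patch itself passes decay=1.9) are all assumptions.

import numpy as np
import tensorflow as tf

# Placeholder inputs; the 28x28x64 feature-map shape is arbitrary.
x = tf.placeholder(tf.float32, shape=[None, 28, 28, 64])
phase_train = tf.placeholder(tf.bool, name="phase_train")

# Direct call to the wrapper that the patched batch_normalize() now
# returns early with; decay=0.9 is an assumed momentum for the moving
# averages, and updates_collections=None applies them in place.
h = tf.contrib.layers.batch_norm(x,
                                 decay=0.9,
                                 updates_collections=None,
                                 epsilon=1e-5,
                                 scale=True,
                                 is_training=phase_train,
                                 scope="conv1_bn")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.rand(8, 28, 28, 64).astype(np.float32)
    out = sess.run(h, feed_dict={x: batch, phase_train: True})
    # beta/gamma now live under the explicit scope (e.g. "conv1_bn/beta:0"),
    # which is the prefix variable_exist() inspects when deciding to reuse.
    for v in tf.global_variables():
        print(v.name)

Because phase_train is a bool tensor, one graph can serve both training and evaluation; with an explicit scope, a second call over the same scope can reuse the existing beta/gamma variables instead of tripping a duplicate-variable error, which is what the reuse check in the patch is for.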