Implementing batch normalization

parent ecd82cb4
...@@ -69,12 +69,12 @@ class Layer(object): ...@@ -69,12 +69,12 @@ class Layer(object):
#with tf.variable_scope(name): #with tf.variable_scope(name):
phase_train = tf.convert_to_tensor(phase_train, dtype=tf.bool) phase_train = tf.convert_to_tensor(phase_train, dtype=tf.bool)
n_out = int(x.get_shape()[-1]) n_out = int(x.get_shape()[-1])
self.beta = tf.Variable(tf.constant(0.0, shape=[n_out], dtype=x.dtype), self.beta = tf.get_variable(name + '_beta',
name=name + '_beta', initializer=tf.constant(0.0, shape=[n_out], dtype=x.dtype),
trainable=True, trainable=True,
dtype=x.dtype) dtype=x.dtype)
self.gamma = tf.Variable(tf.constant(1.0, shape=[n_out], dtype=x.dtype), self.gamma = tf.get_variable(name + '_gamma',
name=name + '_gamma', initializer=tf.constant(1.0, shape=[n_out], dtype=x.dtype),
trainable=True, trainable=True,
dtype=x.dtype) dtype=x.dtype)
......
...@@ -244,7 +244,9 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)): ...@@ -244,7 +244,9 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
self.sequence_net[k].b.assign(hdf5.read(self.sequence_net[k].b.name)).eval(session=session) self.sequence_net[k].b.assign(hdf5.read(self.sequence_net[k].b.name)).eval(session=session)
session.run(self.sequence_net[k].b) session.run(self.sequence_net[k].b)
if self.sequence_net[k].batch_norm:
self.sequence_net[k].beta.assign(hdf5.read(self.sequence_net[k].beta.name)).eval(session=session)
self.sequence_net[k].gamma.assign(hdf5.read(self.sequence_net[k].gamma.name)).eval(session=session)
hdf5.cd("..") hdf5.cd("..")
......
...@@ -54,7 +54,7 @@ def validate_network(validation_data, validation_labels, directory): ...@@ -54,7 +54,7 @@ def validate_network(validation_data, validation_labels, directory):
path = os.path.join(directory, "model.hdf5") path = os.path.join(directory, "model.hdf5")
#path = os.path.join(directory, "model.ckp") #path = os.path.join(directory, "model.ckp")
#scratch = SequenceNetwork(default_feature_layer="fc1") #scratch = SequenceNetwork(default_feature_layer="fc1")
scratch = SequenceNetwork() scratch = SequenceNetwork(default_feature_layer="fc1")
#scratch.load_original(session, os.path.join(directory, "model.ckp")) #scratch.load_original(session, os.path.join(directory, "model.ckp"))
scratch.load(bob.io.base.HDF5File(path), scratch.load(bob.io.base.HDF5File(path),
shape=validation_shape, session=session) shape=validation_shape, session=session)
...@@ -94,8 +94,10 @@ def test_cnn_trainer_scratch(): ...@@ -94,8 +94,10 @@ def test_cnn_trainer_scratch():
prefetch=False, prefetch=False,
temp_dir=directory) temp_dir=directory)
trainer.train(train_data_shuffler) trainer.train(train_data_shuffler)
del trainer
del scratch
import ipdb; ipdb.set_trace(); #import ipdb; ipdb.set_trace();
accuracy = validate_network(validation_data, validation_labels, directory) accuracy = validate_network(validation_data, validation_labels, directory)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment