Commit 136a436c authored by Tiago de Freitas Pereira

Setting up batch normalization

parent 6fe1103a
@@ -8,6 +8,7 @@ from .Layer import Layer
 from operator import mul
 from bob.learn.tensorflow.initialization import Xavier
 from bob.learn.tensorflow.initialization import Constant
+import numpy
 class FullyConnected(Layer):

@@ -68,8 +69,7 @@ class FullyConnected(Layer):
         if len(self.input_layer.get_shape()) == 4:
             shape = self.input_layer.get_shape().as_list()
-            #fc = tf.reshape(self.input_layer, [shape[0], shape[1] * shape[2] * shape[3]])
-            fc = tf.reshape(self.input_layer, [-1, shape[1] * shape[2] * shape[3]])
+            fc = tf.reshape(self.input_layer, [-1, numpy.prod(shape[1:])])
         else:
             fc = self.input_layer
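The reshape change above flattens a 4-D convolutional output into the 2-D matrix a fully connected layer expects. Two things improve: numpy.prod(shape[1:]) computes the flattened feature size in one step, and the -1 batch dimension lets TensorFlow infer the batch size at run time, so graphs built with an unknown (None) batch size no longer break. A minimal sketch of the arithmetic, using numpy only (the shape values here are a made-up example):

import numpy

# Example conv output shape: [batch, height, width, filters];
# the batch size may be unknown (None) when the graph is built.
shape = [None, 28, 28, 10]

# Flattened feature dimension: 28 * 28 * 10 = 7840.
flat_dim = numpy.prod(shape[1:])
print(flat_dim)  # 7840

# The layer then calls, in effect:
#   fc = tf.reshape(self.input_layer, [-1, flat_dim])
# where -1 tells TensorFlow to infer the batch dimension at run time.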
@@ -21,5 +21,5 @@ class InputLayer(Layer):
     def create_variables(self, input_layer):
         return

-    def get_graph(self):
+    def get_graph(self, training_phase=True):
         return self.original_layer
@@ -61,16 +61,21 @@ class Layer(object):
         """
         from tensorflow.python.ops import control_flow_ops
-        name = 'batch_norm'
+        name = "batch_norm"
         with tf.variable_scope(name):
             phase_train = tf.convert_to_tensor(phase_train, dtype=tf.bool)
-            n_out = int(x.get_shape()[3])
+            n_out = int(x.get_shape()[-1])
             beta = tf.Variable(tf.constant(0.0, shape=[n_out], dtype=x.dtype),
                                name=name + '/beta', trainable=True, dtype=x.dtype)
             gamma = tf.Variable(tf.constant(1.0, shape=[n_out], dtype=x.dtype),
                                 name=name + '/gamma', trainable=True, dtype=x.dtype)
-            batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
+            # If signal
+            #if len(x.get_shape()) == 2:
+            #    batch_mean, batch_var = tf.nn.moments(x, [0], name='moments_{0}'.format(name))
+            #else:
+            batch_mean, batch_var = tf.nn.moments(x, range(len(x.get_shape())-1), name='moments_{0}'.format(name))
             ema = tf.train.ExponentialMovingAverage(decay=0.9)

             def mean_var_with_update():
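Two details in this hunk work together: n_out now reads the channel count from the last axis instead of the hard-coded index 3, and tf.nn.moments now averages over every axis except the last, so the same batch-norm code handles both 2-D fully connected outputs (normalize over axis 0) and 4-D convolutional outputs (normalize over axes 0, 1, 2). The hunk is truncated right after mean_var_with_update() is declared; the standalone sketch below shows the conventional completion of this exponential-moving-average pattern. The batch_norm wrapper, the decay parameter, and everything after the mean_var_with_update definition are assumptions inferred from the surrounding code, not taken from this commit:

# Hedged, standalone sketch (TF 1.x-era API) of the full pattern this
# hunk implements; names mirror the diff, but the completion is assumed.
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

def batch_norm(x, phase_train, name="batch_norm", decay=0.9):
    with tf.variable_scope(name):
        phase_train = tf.convert_to_tensor(phase_train, dtype=tf.bool)
        n_out = int(x.get_shape()[-1])
        beta = tf.Variable(tf.constant(0.0, shape=[n_out], dtype=x.dtype),
                           name=name + '/beta', trainable=True, dtype=x.dtype)
        gamma = tf.Variable(tf.constant(1.0, shape=[n_out], dtype=x.dtype),
                            name=name + '/gamma', trainable=True, dtype=x.dtype)
        # Normalize over every axis except the channel axis, so both
        # 2-D (batch, features) and 4-D (batch, h, w, channels) inputs work.
        batch_mean, batch_var = tf.nn.moments(
            x, list(range(len(x.get_shape()) - 1)), name='moments')
        ema = tf.train.ExponentialMovingAverage(decay=decay)

        def mean_var_with_update():
            # Training branch: refresh the moving averages, then return
            # the batch statistics themselves.
            ema_apply_op = ema.apply([batch_mean, batch_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        # Inference branch: fall back to the accumulated moving averages.
        mean, var = control_flow_ops.cond(
            phase_train,
            mean_var_with_update,
            lambda: (ema.average(batch_mean), ema.average(batch_var)))

        # beta shifts and gamma scales the normalized activations.
        return tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)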
@@ -70,7 +70,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
             if training or not isinstance(current_layer, Dropout):
                 current_layer.create_variables(input_offset)
-                input_offset = current_layer.get_graph()
+                input_offset = current_layer.get_graph(training_phase=training)

                 if feature_layer is not None and k == feature_layer:
                     return input_offset
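This one-line change is what threads the training flag from SequenceNetwork down to every layer, so batch normalization can switch between batch statistics (training) and moving averages (evaluation). A self-contained sketch of that plumbing, with hypothetical names (DummyLayer, forward) used only for illustration, not this repository's API:

class DummyLayer(object):
    """Hypothetical layer illustrating the training_phase plumbing."""

    def create_variables(self, input_layer):
        self.input_layer = input_layer

    def get_graph(self, training_phase=True):
        # A real layer would branch here, e.g. apply batch norm with
        # phase_train=training_phase; this stub just passes data through.
        return self.input_layer

def forward(layers, x, training=True):
    # Mirrors the SequenceNetwork loop above: every layer receives the
    # training flag so train and eval graphs normalize consistently.
    for layer in layers:
        layer.create_variables(x)
        x = layer.get_graph(training_phase=training)
    return x

print(forward([DummyLayer(), DummyLayer()], "input-tensor", training=False))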
@@ -5,7 +5,7 @@
 import numpy
 from bob.learn.tensorflow.datashuffler import Memory, SiameseMemory, TripletMemory, Disk, SiameseDisk, TripletDisk, ImageAugmentation
-from bob.learn.tensorflow.network import Chopra, Lenet
+from bob.learn.tensorflow.network import Chopra
 from bob.learn.tensorflow.loss import BaseLoss, ContrastiveLoss, TripletLoss
 from bob.learn.tensorflow.trainers import Trainer, SiameseTrainer, TripletTrainer, constant
@@ -33,11 +33,13 @@ def scratch_network():
                             filters=10,
                             activation=tf.nn.tanh,
                             weights_initialization=Xavier(seed=seed, use_gpu=False),
-                            bias_initialization=Constant(use_gpu=False)))
+                            bias_initialization=Constant(use_gpu=False)
+                            ))
     scratch.add(FullyConnected(name="fc1", output_dim=10,
                                activation=None,
                                weights_initialization=Xavier(seed=seed, use_gpu=False),
-                               bias_initialization=Constant(use_gpu=False)))
+                               bias_initialization=Constant(use_gpu=False)
+                               ))
     return scratch

@@ -90,6 +92,7 @@ def test_cnn_trainer_scratch():
     trainer.train(train_data_shuffler)
     accuracy = validate_network(validation_data, validation_labels, directory)
     assert accuracy > 80
     del scratch