Commit e9f97a17 authored by Pavel KORSHUNOV

Changed FullyConnected to allow different initializations

parent b4ed2c55
...@@ -46,6 +46,7 @@ class FullyConnected(Layer): ...@@ -46,6 +46,7 @@ class FullyConnected(Layer):
weights_initialization=Xavier(), weights_initialization=Xavier(),
bias_initialization=Constant(), bias_initialization=Constant(),
batch_norm=False, batch_norm=False,
init_value=None,
use_gpu=False, use_gpu=False,
): ):
...@@ -61,11 +62,14 @@ class FullyConnected(Layer): ...@@ -61,11 +62,14 @@ class FullyConnected(Layer):
self.W = None self.W = None
self.b = None self.b = None
self.shape = None self.shape = None
self.init_value = init_value
def create_variables(self, input_layer): def create_variables(self, input_layer):
self.input_layer = input_layer self.input_layer = input_layer
if self.W is None: if self.W is None:
input_dim = reduce(mul, self.input_layer.get_shape().as_list()[1:]) input_dim = reduce(mul, self.input_layer.get_shape().as_list()[1:])
if self.init_value is None:
self.init_value = input_dim
variable = "W_" + str(self.name) variable = "W_" + str(self.name)
if self.get_varible_by_name(variable) is not None: if self.get_varible_by_name(variable) is not None:
...@@ -73,7 +77,8 @@ class FullyConnected(Layer): ...@@ -73,7 +77,8 @@ class FullyConnected(Layer):
else: else:
self.W = self.weights_initialization(shape=[input_dim, self.output_dim], self.W = self.weights_initialization(shape=[input_dim, self.output_dim],
name="W_" + str(self.name), name="W_" + str(self.name),
scope="W_" +str(self.name) scope="W_" +str(self.name),
init_value=self.init_value
) )
# if self.activation is not None: # if self.activation is not None:
variable = "b_" + str(self.name) variable = "b_" + str(self.name)
...@@ -82,14 +87,15 @@ class FullyConnected(Layer): ...@@ -82,14 +87,15 @@ class FullyConnected(Layer):
else: else:
self.b = self.bias_initialization(shape=[self.output_dim], self.b = self.bias_initialization(shape=[self.output_dim],
name="b_" + str(self.name), name="b_" + str(self.name),
scope="b_" + str(self.name) scope="b_" + str(self.name),
init_value=self.init_value
) )
def get_graph(self, training_phase=True): def get_graph(self, training_phase=True):
with tf.name_scope(str(self.name)): with tf.name_scope(str(self.name)):
if len(self.input_layer.get_shape()) == 4: if len(self.input_layer.get_shape()) == 4 or len(self.input_layer.get_shape()) == 3:
shape = self.input_layer.get_shape().as_list() shape = self.input_layer.get_shape().as_list()
fc = tf.reshape(self.input_layer, [-1, numpy.prod(shape[1:])]) fc = tf.reshape(self.input_layer, [-1, numpy.prod(shape[1:])])
else: else:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment