Skip to content
Snippets Groups Projects
Commit e9f97a17 authored by Pavel KORSHUNOV's avatar Pavel KORSHUNOV
Browse files

Changed FullyConnected to allow different initializations

parent b4ed2c55
No related branches found
No related tags found
1 merge request!2Added support for audio databases
...@@ -46,6 +46,7 @@ class FullyConnected(Layer): ...@@ -46,6 +46,7 @@ class FullyConnected(Layer):
weights_initialization=Xavier(), weights_initialization=Xavier(),
bias_initialization=Constant(), bias_initialization=Constant(),
batch_norm=False, batch_norm=False,
init_value=None,
use_gpu=False, use_gpu=False,
): ):
...@@ -61,11 +62,14 @@ class FullyConnected(Layer): ...@@ -61,11 +62,14 @@ class FullyConnected(Layer):
self.W = None self.W = None
self.b = None self.b = None
self.shape = None self.shape = None
self.init_value = init_value
def create_variables(self, input_layer): def create_variables(self, input_layer):
self.input_layer = input_layer self.input_layer = input_layer
if self.W is None: if self.W is None:
input_dim = reduce(mul, self.input_layer.get_shape().as_list()[1:]) input_dim = reduce(mul, self.input_layer.get_shape().as_list()[1:])
if self.init_value is None:
self.init_value = input_dim
variable = "W_" + str(self.name) variable = "W_" + str(self.name)
if self.get_varible_by_name(variable) is not None: if self.get_varible_by_name(variable) is not None:
...@@ -73,7 +77,8 @@ class FullyConnected(Layer): ...@@ -73,7 +77,8 @@ class FullyConnected(Layer):
else: else:
self.W = self.weights_initialization(shape=[input_dim, self.output_dim], self.W = self.weights_initialization(shape=[input_dim, self.output_dim],
name="W_" + str(self.name), name="W_" + str(self.name),
scope="W_" +str(self.name) scope="W_" +str(self.name),
init_value=self.init_value
) )
# if self.activation is not None: # if self.activation is not None:
variable = "b_" + str(self.name) variable = "b_" + str(self.name)
...@@ -82,14 +87,15 @@ class FullyConnected(Layer): ...@@ -82,14 +87,15 @@ class FullyConnected(Layer):
else: else:
self.b = self.bias_initialization(shape=[self.output_dim], self.b = self.bias_initialization(shape=[self.output_dim],
name="b_" + str(self.name), name="b_" + str(self.name),
scope="b_" + str(self.name) scope="b_" + str(self.name),
init_value=self.init_value
) )
def get_graph(self, training_phase=True): def get_graph(self, training_phase=True):
with tf.name_scope(str(self.name)): with tf.name_scope(str(self.name)):
if len(self.input_layer.get_shape()) == 4: if len(self.input_layer.get_shape()) == 4 or len(self.input_layer.get_shape()) == 3:
shape = self.input_layer.get_shape().as_list() shape = self.input_layer.get_shape().as_list()
fc = tf.reshape(self.input_layer, [-1, numpy.prod(shape[1:])]) fc = tf.reshape(self.input_layer, [-1, numpy.prod(shape[1:])])
else: else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment