Skip to content
Snippets Groups Projects
Commit e9f97a17 authored by Pavel KORSHUNOV's avatar Pavel KORSHUNOV
Browse files

Changed FullyConnected to allow different initializations

parent b4ed2c55
Branches
Tags
1 merge request!2Added support for audio databases
......@@ -46,6 +46,7 @@ class FullyConnected(Layer):
weights_initialization=Xavier(),
bias_initialization=Constant(),
batch_norm=False,
init_value=None,
use_gpu=False,
):
......@@ -61,11 +62,14 @@ class FullyConnected(Layer):
self.W = None
self.b = None
self.shape = None
self.init_value = init_value
def create_variables(self, input_layer):
self.input_layer = input_layer
if self.W is None:
input_dim = reduce(mul, self.input_layer.get_shape().as_list()[1:])
if self.init_value is None:
self.init_value = input_dim
variable = "W_" + str(self.name)
if self.get_varible_by_name(variable) is not None:
......@@ -73,7 +77,8 @@ class FullyConnected(Layer):
else:
self.W = self.weights_initialization(shape=[input_dim, self.output_dim],
name="W_" + str(self.name),
scope="W_" +str(self.name)
scope="W_" +str(self.name),
init_value=self.init_value
)
# if self.activation is not None:
variable = "b_" + str(self.name)
......@@ -82,14 +87,15 @@ class FullyConnected(Layer):
else:
self.b = self.bias_initialization(shape=[self.output_dim],
name="b_" + str(self.name),
scope="b_" + str(self.name)
scope="b_" + str(self.name),
init_value=self.init_value
)
def get_graph(self, training_phase=True):
with tf.name_scope(str(self.name)):
if len(self.input_layer.get_shape()) == 4:
if len(self.input_layer.get_shape()) == 4 or len(self.input_layer.get_shape()) == 3:
shape = self.input_layer.get_shape().as_list()
fc = tf.reshape(self.input_layer, [-1, numpy.prod(shape[1:])])
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment