From e9f97a17bc46e71553ab0a9602ff80f6189de4eb Mon Sep 17 00:00:00 2001
From: Pavel Korshunov <pavel.korshunov@idiap.ch>
Date: Mon, 21 Nov 2016 17:00:53 +0100
Subject: [PATCH] Change FullyConnected to allow different initializations

---
 bob/learn/tensorflow/layers/FullyConnected.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)
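
Note: init_value is forwarded to both initialization callables, and
create_variables() falls back to the flattened input dimension when it
is left at None. A minimal usage sketch, assuming the initializers live
in bob.learn.tensorflow.initialization and that the constructor's
leading arguments are name and output_dim (neither is shown in these
hunks):

    from bob.learn.tensorflow.layers import FullyConnected
    # module path below is an assumption
    from bob.learn.tensorflow.initialization import Xavier, Constant

    # Default: init_value stays None, so create_variables() substitutes
    # the flattened input dimension before calling the initializers.
    fc_default = FullyConnected(name="fc1", output_dim=128)  # leading args assumed

    # Explicit: forward a fixed value to Xavier() and Constant() instead.
    fc_fixed = FullyConnected(name="fc2", output_dim=128,
                              weights_initialization=Xavier(),
                              bias_initialization=Constant(),
                              init_value=512)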

diff --git a/bob/learn/tensorflow/layers/FullyConnected.py b/bob/learn/tensorflow/layers/FullyConnected.py
index 5e6890fb..cc2ec329 100644
--- a/bob/learn/tensorflow/layers/FullyConnected.py
+++ b/bob/learn/tensorflow/layers/FullyConnected.py
@@ -46,6 +46,7 @@ class FullyConnected(Layer):
                  weights_initialization=Xavier(),
                  bias_initialization=Constant(),
                  batch_norm=False,
+                 init_value=None,
                  use_gpu=False,
                  ):
 
@@ -61,11 +62,14 @@ class FullyConnected(Layer):
         self.W = None
         self.b = None
         self.shape = None
+        self.init_value = init_value
 
     def create_variables(self, input_layer):
         self.input_layer = input_layer
         if self.W is None:
             input_dim = reduce(mul, self.input_layer.get_shape().as_list()[1:])
+            if self.init_value is None:
+                self.init_value = input_dim  # default to the flattened input dimension
 
             variable = "W_" + str(self.name)
             if self.get_varible_by_name(variable) is not None:
@@ -73,7 +77,8 @@ class FullyConnected(Layer):
             else:
                 self.W = self.weights_initialization(shape=[input_dim, self.output_dim],
                                                      name="W_" + str(self.name),
-                                                     scope="W_" +str(self.name)
+                                                     scope="W_" + str(self.name),
+                                                     init_value=self.init_value
                                                      )
             # if self.activation is not None:
             variable = "b_" + str(self.name)
@@ -82,14 +87,15 @@ class FullyConnected(Layer):
             else:
                 self.b = self.bias_initialization(shape=[self.output_dim],
                                                   name="b_" + str(self.name),
-                                                  scope="b_" + str(self.name)
+                                                  scope="b_" + str(self.name),
+                                                  init_value=self.init_value
                                                  )
 
     def get_graph(self, training_phase=True):
 
         with tf.name_scope(str(self.name)):
 
-            if len(self.input_layer.get_shape()) == 4:
+            if len(self.input_layer.get_shape()) in (3, 4):
                 shape = self.input_layer.get_shape().as_list()
                 fc = tf.reshape(self.input_layer, [-1, numpy.prod(shape[1:])])
             else:
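
The get_graph() change additionally flattens rank-3 inputs (e.g. sequence
features shaped [batch, time, features]) the same way rank-4 image batches
were already handled. A short sketch of the reshape arithmetic, with example
dimensions assumed for illustration:

    import numpy

    shape = [32, 10, 64]              # assumed rank-3 input: [batch, time, features]
    flat_dim = numpy.prod(shape[1:])  # 640: every axis except batch is merged
    # tf.reshape(input_layer, [-1, flat_dim]) then yields a [32, 640] tensor,
    # matching the [input_dim, output_dim] weight matrix created above.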
-- 
GitLab