diff --git a/bob/learn/tensorflow/layers/LSTM.py b/bob/learn/tensorflow/layers/LSTM.py
index de16784601e5ed58600ba830163cd8eda57c336c..0b07ad7aca3ae9963971d27eb6237e7604376756 100644
--- a/bob/learn/tensorflow/layers/LSTM.py
+++ b/bob/learn/tensorflow/layers/LSTM.py
@@ -14,10 +14,10 @@ logger = logging.getLogger("bob.learn")
 def lstm(inputs, lstm_cell_size, lstm_fn=tf.contrib.rnn.BasicLSTMCell, num_time_steps=20,
          output_activation_size=10, batch_size=10, scope='rnn',
          weights_initializer=tf.random_normal, activation=tf.nn.relu,
-         name=None, reuse=None, dropout=False, input_dropout=1.0, output_dropout=1.0):
+         name=None, reuse=None, dropout=False, input_dropout=1.0, output_dropout=1.0, full_output=False):
     """
     """
-    return LSTM(lstm_cell_size=lstm_cell_size,
+    output = LSTM(lstm_cell_size=lstm_cell_size,
                 num_time_steps=num_time_steps,
                 batch_size=batch_size,
                 lstm_fn=lstm_fn,
@@ -30,6 +30,11 @@ def lstm(inputs, lstm_cell_size, lstm_fn=tf.contrib.rnn.BasicLSTMCell, num_time_
                 output_dropout=output_dropout,
                 name=name,
                 reuse=reuse)(inputs)
+    if full_output:
+        logger.info("LSTM: the number of the outputs: {0}".format(len(output)))
+        return output
+    logger.info("LSTM: the shape of the output: {0}".format(output[-1].shape))
+    return output[-1]
 
 
 class LSTM(base.Layer):
@@ -95,23 +100,30 @@ class LSTM(base.Layer):
         shape = inputs.get_shape().as_list()
         logger.info("LSTM: the shape of the inputs: {0}".format(shape))
 
-        input_time_steps = shape[1]  # second dimension must be the number of time steps in LSTM
-
-        if len(shape) == 4:  # when inputs shape is 4, the last dimension must be 1
-            if shape[-1] == 1:  # we accept last dimension to be 1, then we just reshape it
-                inputs = tf.reshape(inputs, shape=(-1, shape[1], shape[2]))
-                logger.info("LSTM: after reshape, the shape of the inputs: {0}".format(inputs.get_shape().as_list()))
-            else:
-                raise ValueError('The shape of input must be either (batch_size, num_time_steps, input_vector_size) or '
-                                 '(batch_size, num_time_steps, input_vector_size, 1), but it is {}'.format(shape))
-
-        if input_time_steps % self.num_time_steps:
-            raise ValueError('number of rows in one batch of input ({}) should be '
-                             'the same as the num_time_steps of LSTM ({})'
-                             .format(input_time_steps, self.num_time_steps))
-
-        # convert inputs into the num_time_steps list of the inputs each of shape (batch_size, input_vector_size)
-        list_inputs = tf.unstack(inputs, self.num_time_steps, 1)
+        # if the input is already formatted correctly, just use it as is
+        if shape[1] == self.batch_size and shape[0] == self.num_time_steps:
+            inputs.set_shape((shape[0], None, shape[2]))
+            logger.info("LSTM: undefine batch shape inputs: {0}".format(inputs.get_shape().as_list()))
+            list_inputs = tf.unstack(inputs, self.num_time_steps, 0)
+        # here we consider all special cases
+        else:
+            input_time_steps = shape[1]  # second dimension must be the number of time steps in LSTM
+
+            if len(shape) == 4:  # when inputs shape is 4, the last dimension must be 1
+                if shape[-1] == 1:  # we accept last dimension to be 1, then we just reshape it
+                    inputs = tf.reshape(inputs, shape=(-1, shape[1], shape[2]))
+                    logger.info("LSTM: after reshape, the shape of the inputs: {0}".format(inputs.get_shape().as_list()))
+                else:
+                    raise ValueError('The shape of input must be either (batch_size, num_time_steps, input_vector_size) or '
+                                     '(batch_size, num_time_steps, input_vector_size, 1), but it is {}'.format(shape))
+
+            if input_time_steps % self.num_time_steps:
+                raise ValueError('number of rows in one batch of input ({}) should be '
+                                 'the same as the num_time_steps of LSTM ({})'
+                                 .format(input_time_steps, self.num_time_steps))
+
+            # convert inputs into the num_time_steps list of the inputs each of shape (batch_size, input_vector_size)
+            list_inputs = tf.unstack(inputs, self.num_time_steps, 1)
 
         # run LSTM training on the batch of inputs
         # return the output (a list of self.num_time_steps outputs each of size input_vector_size)
@@ -123,6 +135,6 @@ class LSTM(base.Layer):
                                                          scope=self.scope)
 
         # consider the output of the last cell
-        return outputs[-1]
+        return outputs
         # return tf.matmul(outputs[-1], self.output_activation_weights['out']) + self.output_activation_biases['out']
 
diff --git a/bob/learn/tensorflow/network/__init__.py b/bob/learn/tensorflow/network/__init__.py
index d83b5892af9c4aad34fb23f381c20a846ec9d366..b7fc9cc50bb7c7c5e6018ea4bee2fccf8c2b9c94 100644
--- a/bob/learn/tensorflow/network/__init__.py
+++ b/bob/learn/tensorflow/network/__init__.py
@@ -3,7 +3,7 @@ from .LightCNN9 import LightCNN9
 from .Dummy import Dummy
 from .MLP import MLP
 from .Embedding import Embedding
-from .lstm import simple_lstm_network
+from .lstm import simple_lstm_network, double_lstm_network, triple_lstm_network
 from .lstm import RegularizedLoss
 from .simplemlp import mlp_network
 from .simplecnn import simple2Dcnn_network
diff --git a/bob/learn/tensorflow/network/lstm.py b/bob/learn/tensorflow/network/lstm.py
index 7919767d13e7d6951506fe247c05777a57201d85..b0af20da3c91525776d84eecbb2d2409039e0a0b 100644
--- a/bob/learn/tensorflow/network/lstm.py
+++ b/bob/learn/tensorflow/network/lstm.py
@@ -69,3 +69,71 @@ def simple_lstm_network(train_data_shuffler, lstm_cell_size=64, batch_size=10,
                                  weights_initializer=initializer, weights_regularizer=regularizer, reuse=reuse)
 
     return graph
+
+
+def double_lstm_network(train_data_shuffler, lstm_cell_size=64, batch_size=10,
+                        num_time_steps=28, num_classes=10, seed=10, reuse=False,
+                        dropout=False, input_dropout=1.0, output_dropout=1.0):
+
+    if isinstance(train_data_shuffler, tf.Tensor):
+        inputs = train_data_shuffler
+    else:
+        inputs = train_data_shuffler("data", from_queue=False)
+
+    initializer = tf.contrib.layers.xavier_initializer(seed=seed)
+
+    # First LSTM layer network
+    graph = lstm(inputs, lstm_cell_size, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm1', name='sync_cell_l1',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout, full_output=True)
+
+    # Second LSTM layer network of twice smaller size
+    graph = lstm(graph, lstm_cell_size/2, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm2', name='sync_cell_l2',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout)
+
+    regularizer = None
+    # fully connect the LSTM output to the classes
+    graph = slim.fully_connected(graph, num_classes, activation_fn=None, scope='fc1',
+                                 weights_initializer=initializer, weights_regularizer=regularizer, reuse=reuse)
+
+    return graph
+
+
+def triple_lstm_network(train_data_shuffler, lstm_cell_size=64, batch_size=10,
+                        num_time_steps=28, num_classes=10, seed=10, reuse=False,
+                        dropout=False, input_dropout=1.0, output_dropout=1.0):
+
+    if isinstance(train_data_shuffler, tf.Tensor):
+        inputs = train_data_shuffler
+    else:
+        inputs = train_data_shuffler("data", from_queue=False)
+
+    initializer = tf.contrib.layers.xavier_initializer(seed=seed)
+
+    # First LSTM layer network
+    graph = lstm(inputs, lstm_cell_size, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm1', name='sync_cell_l1',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout, full_output=True)
+
+    # Second LSTM layer network of twice smaller size
+    graph = lstm(graph, lstm_cell_size/2, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm2', name='sync_cell_l2',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout, full_output=True)
+
+    # Third LSTM layer network three time smaller size
+    graph = lstm(graph, lstm_cell_size/4, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm3', name='sync_cell_l3',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout)
+
+    regularizer = None
+    # fully connect the LSTM output to the classes
+    graph = slim.fully_connected(graph, num_classes, activation_fn=None, scope='fc1',
+                                 weights_initializer=initializer, weights_regularizer=regularizer, reuse=reuse)
+
+    return graph