diff --git a/bob/learn/tensorflow/layers/LSTM.py b/bob/learn/tensorflow/layers/LSTM.py
index de16784601e5ed58600ba830163cd8eda57c336c..0b07ad7aca3ae9963971d27eb6237e7604376756 100644
--- a/bob/learn/tensorflow/layers/LSTM.py
+++ b/bob/learn/tensorflow/layers/LSTM.py
@@ -14,10 +14,10 @@ logger = logging.getLogger("bob.learn")
 def lstm(inputs, lstm_cell_size, lstm_fn=tf.contrib.rnn.BasicLSTMCell, num_time_steps=20,
          output_activation_size=10, batch_size=10, scope='rnn',
          weights_initializer=tf.random_normal, activation=tf.nn.relu,
-         name=None, reuse=None, dropout=False, input_dropout=1.0, output_dropout=1.0):
+         name=None, reuse=None, dropout=False, input_dropout=1.0, output_dropout=1.0, full_output=False):
     """
     """
-    return LSTM(lstm_cell_size=lstm_cell_size,
+    output = LSTM(lstm_cell_size=lstm_cell_size,
                 num_time_steps=num_time_steps,
                 batch_size=batch_size,
                 lstm_fn=lstm_fn,
@@ -30,6 +30,11 @@ def lstm(inputs, lstm_cell_size, lstm_fn=tf.contrib.rnn.BasicLSTMCell, num_time_
                 output_dropout=output_dropout,
                 name=name,
                 reuse=reuse)(inputs)
+    if full_output:
+        logger.info("LSTM: the number of outputs: {0}".format(len(output)))
+        return output
+    logger.info("LSTM: the shape of the output: {0}".format(output[-1].shape))
+    return output[-1]
 
 
 class LSTM(base.Layer):
@@ -95,23 +100,30 @@ class LSTM(base.Layer):
         shape = inputs.get_shape().as_list()
         logger.info("LSTM: the shape of the inputs: {0}".format(shape))
 
-        input_time_steps = shape[1]  # second dimension must be the number of time steps in LSTM
-
-        if len(shape) == 4:  # when inputs shape is 4, the last dimension must be 1
-            if shape[-1] == 1:  # we accept last dimension to be 1, then we just reshape it
-                inputs = tf.reshape(inputs, shape=(-1, shape[1], shape[2]))
-                logger.info("LSTM: after reshape, the shape of the inputs: {0}".format(inputs.get_shape().as_list()))
-            else:
-                raise ValueError('The shape of input must be either (batch_size, num_time_steps, input_vector_size) or '
-                                 '(batch_size, num_time_steps, input_vector_size, 1), but it is {}'.format(shape))
-
-        if input_time_steps % self.num_time_steps:
-            raise ValueError('number of rows in one batch of input ({}) should be '
-                             'the same as the num_time_steps of LSTM ({})'
-                             .format(input_time_steps, self.num_time_steps))
-
-        # convert inputs into the num_time_steps list of the inputs each of shape (batch_size, input_vector_size)
-        list_inputs = tf.unstack(inputs, self.num_time_steps, 1)
+        # if the input is already time-major, i.e. (num_time_steps, batch_size, input_vector_size), use it as is
+        if shape[1] == self.batch_size and shape[0] == self.num_time_steps:
+            inputs.set_shape((shape[0], None, shape[2]))
+            logger.info("LSTM: inputs with undefined batch dimension: {0}".format(inputs.get_shape().as_list()))
+            list_inputs = tf.unstack(inputs, self.num_time_steps, 0)
+        # otherwise, handle the batch-major layouts and the remaining special cases
+        else:
+            input_time_steps = shape[1]  # second dimension must be the number of time steps in LSTM
+
+            if len(shape) == 4:  # when inputs shape is 4, the last dimension must be 1
+                if shape[-1] == 1:  # we accept last dimension to be 1, then we just reshape it
+                    inputs = tf.reshape(inputs, shape=(-1, shape[1], shape[2]))
+                    logger.info("LSTM: after reshape, the shape of the inputs: {0}".format(inputs.get_shape().as_list()))
+                else:
+                    raise ValueError('The shape of input must be either (batch_size, num_time_steps, input_vector_size) or '
+                                     '(batch_size, num_time_steps, input_vector_size, 1), but it is {}'.format(shape))
+
+            if input_time_steps % self.num_time_steps:
+                raise ValueError('number of rows in one batch of input ({}) should be '
+                                 'the same as the num_time_steps of LSTM ({})'
+                                 .format(input_time_steps, self.num_time_steps))
+
+            # convert inputs into the num_time_steps list of the inputs each of shape (batch_size, input_vector_size)
+            list_inputs = tf.unstack(inputs, self.num_time_steps, 1)
 
         # run LSTM training on the batch of inputs
         # return the output (a list of self.num_time_steps outputs each of size input_vector_size)
@@ -123,6 +135,6 @@
                                                 scope=self.scope)
 
-        # consider the output of the last cell
-        return outputs[-1]
+        # return the full list of per-time-step outputs; lstm() picks the last one unless full_output is set
+        return outputs
 
 #        return tf.matmul(outputs[-1], self.output_activation_weights['out']) + self.output_activation_biases['out']
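Below is a minimal usage sketch of the changed lstm() wrapper; it is not part of the patch. The shapes and scope names are illustrative, TF 1.x graph mode is assumed, and lstm is assumed to be importable from bob.learn.tensorflow.layers.

    import tensorflow as tf
    from bob.learn.tensorflow.layers import lstm

    # a batch of 10 sequences with 28 time steps of 28 features each (e.g. MNIST rows)
    data = tf.placeholder(tf.float32, shape=(10, 28, 28), name="data")

    # default behaviour: only the output of the last time step is returned, shape (10, 64)
    last_out = lstm(data, 64, num_time_steps=28, batch_size=10, scope='lstm_last')

    # full_output=True returns the whole list of 28 per-time-step outputs,
    # which is what the stacked networks below feed into the next LSTM layer
    all_outs = lstm(data, 64, num_time_steps=28, batch_size=10, scope='lstm_all',
                    full_output=True)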
diff --git a/bob/learn/tensorflow/network/__init__.py b/bob/learn/tensorflow/network/__init__.py
index d83b5892af9c4aad34fb23f381c20a846ec9d366..b7fc9cc50bb7c7c5e6018ea4bee2fccf8c2b9c94 100644
--- a/bob/learn/tensorflow/network/__init__.py
+++ b/bob/learn/tensorflow/network/__init__.py
@@ -3,7 +3,7 @@ from .LightCNN9 import LightCNN9
 from .Dummy import Dummy
 from .MLP import MLP
 from .Embedding import Embedding
-from .lstm import simple_lstm_network
+from .lstm import simple_lstm_network, double_lstm_network, triple_lstm_network
 from .lstm import RegularizedLoss
 from .simplemlp import mlp_network
 from .simplecnn import simple2Dcnn_network
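With the new exports in place, the stacked constructors can be used directly from the package. A hedged sketch follows; the hyper-parameters and the input placeholder are illustrative, not from the patch.

    import tensorflow as tf
    from bob.learn.tensorflow.network import double_lstm_network

    data = tf.placeholder(tf.float32, shape=(10, 28, 28), name="data")

    # two stacked LSTM layers (64 and 32 cells) followed by a fully connected layer to 10 classes
    logits = double_lstm_network(data, lstm_cell_size=64, batch_size=10,
                                 num_time_steps=28, num_classes=10)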
diff --git a/bob/learn/tensorflow/network/lstm.py b/bob/learn/tensorflow/network/lstm.py
index 7919767d13e7d6951506fe247c05777a57201d85..b0af20da3c91525776d84eecbb2d2409039e0a0b 100644
--- a/bob/learn/tensorflow/network/lstm.py
+++ b/bob/learn/tensorflow/network/lstm.py
@@ -69,3 +69,71 @@ def simple_lstm_network(train_data_shuffler, lstm_cell_size=64, batch_size=10,
                                  weights_initializer=initializer, weights_regularizer=regularizer, reuse=reuse)
 
     return graph
+
+
+def double_lstm_network(train_data_shuffler, lstm_cell_size=64, batch_size=10,
+                        num_time_steps=28, num_classes=10, seed=10, reuse=False,
+                        dropout=False, input_dropout=1.0, output_dropout=1.0):
+
+    if isinstance(train_data_shuffler, tf.Tensor):
+        inputs = train_data_shuffler
+    else:
+        inputs = train_data_shuffler("data", from_queue=False)
+
+    initializer = tf.contrib.layers.xavier_initializer(seed=seed)
+
+    # First LSTM layer; full_output=True passes the whole list of per-time-step outputs to the next layer
+    graph = lstm(inputs, lstm_cell_size, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm1', name='sync_cell_l1',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout, full_output=True)
+
+    # Second LSTM layer with half the cell size
+    graph = lstm(graph, lstm_cell_size // 2, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm2', name='sync_cell_l2',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout)
+
+    regularizer = None
+    # fully connect the LSTM output to the classes
+    graph = slim.fully_connected(graph, num_classes, activation_fn=None, scope='fc1',
+                                 weights_initializer=initializer, weights_regularizer=regularizer, reuse=reuse)
+
+    return graph
+
+
+def triple_lstm_network(train_data_shuffler, lstm_cell_size=64, batch_size=10,
+                        num_time_steps=28, num_classes=10, seed=10, reuse=False,
+                        dropout=False, input_dropout=1.0, output_dropout=1.0):
+
+    if isinstance(train_data_shuffler, tf.Tensor):
+        inputs = train_data_shuffler
+    else:
+        inputs = train_data_shuffler("data", from_queue=False)
+
+    initializer = tf.contrib.layers.xavier_initializer(seed=seed)
+
+    # First LSTM layer; full_output=True passes the whole list of per-time-step outputs to the next layer
+    graph = lstm(inputs, lstm_cell_size, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm1', name='sync_cell_l1',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout, full_output=True)
+
+    # Second LSTM layer with half the cell size
+    graph = lstm(graph, lstm_cell_size // 2, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm2', name='sync_cell_l2',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout, full_output=True)
+
+    # Third LSTM layer with a quarter of the cell size
+    graph = lstm(graph, lstm_cell_size // 4, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm3', name='sync_cell_l3',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout)
+
+    regularizer = None
+    # fully connect the LSTM output to the classes
+    graph = slim.fully_connected(graph, num_classes, activation_fn=None, scope='fc1',
+                                 weights_initializer=initializer, weights_regularizer=regularizer, reuse=reuse)
+
+    return graph
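For completeness, a sketch of how triple_lstm_network could be wired into a classification loss. The first two layers pass their full list of per-time-step outputs forward (full_output=True), the last layer returns only its final output, and fc1 maps it to num_classes; the softmax cross-entropy below is standard TensorFlow, not something this patch adds, and the input shape is only an example.

    import tensorflow as tf
    from bob.learn.tensorflow.network import triple_lstm_network

    data = tf.placeholder(tf.float32, shape=(10, 28, 28), name="data")
    labels = tf.placeholder(tf.int64, shape=(10,), name="labels")

    # cell sizes 64 -> 32 -> 16, then a fully connected layer down to the 10 classes
    logits = triple_lstm_network(data, lstm_cell_size=64, batch_size=10,
                                 num_time_steps=28, num_classes=10)
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))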