Commit 289c1030 authored by Pavel KORSHUNOV

added support for 2- and 3-layered LSTM

parent 45082083
@@ -14,10 +14,10 @@ logger = logging.getLogger("bob.learn")
 def lstm(inputs, lstm_cell_size, lstm_fn=tf.contrib.rnn.BasicLSTMCell, num_time_steps=20,
          output_activation_size=10, batch_size=10, scope='rnn',
          weights_initializer=tf.random_normal, activation=tf.nn.relu,
-         name=None, reuse=None, dropout=False, input_dropout=1.0, output_dropout=1.0):
+         name=None, reuse=None, dropout=False, input_dropout=1.0, output_dropout=1.0, full_output=False):
     """
     """
-    return LSTM(lstm_cell_size=lstm_cell_size,
+    output = LSTM(lstm_cell_size=lstm_cell_size,
                 num_time_steps=num_time_steps,
                 batch_size=batch_size,
                 lstm_fn=lstm_fn,
@@ -30,6 +30,11 @@ def lstm(inputs, lstm_cell_size, lstm_fn=tf.contrib.rnn.BasicLSTMCell, num_time_
                 output_dropout=output_dropout,
                 name=name,
                 reuse=reuse)(inputs)
+    if full_output:
+        logger.info("LSTM: the number of the outputs: {0}".format(len(output)))
+        return output
+    logger.info("LSTM: the shape of the output: {0}".format(output[-1].shape))
+    return output[-1]
 
 
 class LSTM(base.Layer):
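The new full_output flag controls what the wrapper hands back: by default it keeps the previous behaviour and returns only the last time step's output, while full_output=True returns the whole per-time-step output list so that a second LSTM layer can consume it. A minimal sketch of the two call styles (the placeholder shape and scope names here are illustrative, not taken from the repository):

    import tensorflow as tf

    # illustrative input: 10 sequences of 20 time steps with 28 features each
    inputs = tf.placeholder(tf.float32, shape=(10, 20, 28))

    last_step = lstm(inputs, lstm_cell_size=64, num_time_steps=20, batch_size=10,
                     scope='demo_last')                     # one (batch_size, cell_size) tensor
    all_steps = lstm(inputs, lstm_cell_size=64, num_time_steps=20, batch_size=10,
                     scope='demo_all', full_output=True)    # list of num_time_steps such tensors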
@@ -95,23 +100,30 @@ class LSTM(base.Layer):
         shape = inputs.get_shape().as_list()
         logger.info("LSTM: the shape of the inputs: {0}".format(shape))
-        input_time_steps = shape[1]  # second dimension must be the number of time steps in LSTM
-
-        if len(shape) == 4:  # when inputs shape is 4, the last dimension must be 1
-            if shape[-1] == 1:  # we accept last dimension to be 1, then we just reshape it
-                inputs = tf.reshape(inputs, shape=(-1, shape[1], shape[2]))
-                logger.info("LSTM: after reshape, the shape of the inputs: {0}".format(inputs.get_shape().as_list()))
-            else:
-                raise ValueError('The shape of input must be either (batch_size, num_time_steps, input_vector_size) or '
-                                 '(batch_size, num_time_steps, input_vector_size, 1), but it is {}'.format(shape))
-
-        if input_time_steps % self.num_time_steps:
-            raise ValueError('number of rows in one batch of input ({}) should be '
-                             'the same as the num_time_steps of LSTM ({})'
-                             .format(input_time_steps, self.num_time_steps))
-
-        # convert inputs into the num_time_steps list of the inputs each of shape (batch_size, input_vector_size)
-        list_inputs = tf.unstack(inputs, self.num_time_steps, 1)
+        # if the input is already formatted correctly, just use it as is
+        if shape[1] == self.batch_size and shape[0] == self.num_time_steps:
+            inputs.set_shape((shape[0], None, shape[2]))
+            logger.info("LSTM: undefined batch shape inputs: {0}".format(inputs.get_shape().as_list()))
+            list_inputs = tf.unstack(inputs, self.num_time_steps, 0)
+        # here we consider all the special cases
+        else:
+            input_time_steps = shape[1]  # second dimension must be the number of time steps in LSTM
+
+            if len(shape) == 4:  # when inputs shape is 4, the last dimension must be 1
+                if shape[-1] == 1:  # we accept last dimension to be 1, then we just reshape it
+                    inputs = tf.reshape(inputs, shape=(-1, shape[1], shape[2]))
+                    logger.info("LSTM: after reshape, the shape of the inputs: {0}".format(inputs.get_shape().as_list()))
+                else:
+                    raise ValueError('The shape of input must be either (batch_size, num_time_steps, input_vector_size) or '
+                                     '(batch_size, num_time_steps, input_vector_size, 1), but it is {}'.format(shape))
+
+            if input_time_steps % self.num_time_steps:
+                raise ValueError('number of rows in one batch of input ({}) should be '
+                                 'the same as the num_time_steps of LSTM ({})'
+                                 .format(input_time_steps, self.num_time_steps))
+
+            # convert inputs into the num_time_steps list of the inputs, each of shape (batch_size, input_vector_size)
+            list_inputs = tf.unstack(inputs, self.num_time_steps, 1)
 
         # run LSTM training on the batch of inputs
         # return the output (a list of self.num_time_steps outputs each of size input_vector_size)
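The two branches above differ only in which axis tf.unstack splits: input already laid out as (num_time_steps, batch_size, features) is split along axis 0, while the usual (batch_size, num_time_steps, features) layout is split along axis 1. A standalone TensorFlow 1.x check with toy shapes:

    import numpy as np
    import tensorflow as tf

    time_major = tf.constant(np.zeros((4, 5, 3), dtype=np.float32))   # (time, batch, features)
    batch_major = tf.constant(np.zeros((5, 4, 3), dtype=np.float32))  # (batch, time, features)

    steps_a = tf.unstack(time_major, num=4, axis=0)   # 4 tensors, each of shape (5, 3)
    steps_b = tf.unstack(batch_major, num=4, axis=1)  # 4 tensors, each of shape (5, 3)
    assert len(steps_a) == len(steps_b) == 4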
@@ -123,6 +135,6 @@ class LSTM(base.Layer):
                                                 scope=self.scope)
 
         # consider the output of the last cell
-        return outputs[-1]
+        return outputs
         # return tf.matmul(outputs[-1], self.output_activation_weights['out']) + self.output_activation_biases['out']
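With this change the layer itself always returns the full list of per-time-step outputs; selecting the last step versus the whole sequence now happens one level up, via the lstm() wrapper's full_output flag. When that list is fed to the next LSTM layer it arrives stacked time-major as (num_time_steps, batch_size, cell_size), which appears to be exactly the case the new branch in the shape handling above covers.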
@@ -3,7 +3,7 @@ from .LightCNN9 import LightCNN9
 from .Dummy import Dummy
 from .MLP import MLP
 from .Embedding import Embedding
-from .lstm import simple_lstm_network
+from .lstm import simple_lstm_network, double_lstm_network, triple_lstm_network
 from .lstm import RegularizedLoss
 from .simplemlp import mlp_network
 from .simplecnn import simple2Dcnn_network
@@ -69,3 +69,71 @@ def simple_lstm_network(train_data_shuffler, lstm_cell_size=64, batch_size=10,
                                  weights_initializer=initializer, weights_regularizer=regularizer, reuse=reuse)
     return graph
+
+
+def double_lstm_network(train_data_shuffler, lstm_cell_size=64, batch_size=10,
+                        num_time_steps=28, num_classes=10, seed=10, reuse=False,
+                        dropout=False, input_dropout=1.0, output_dropout=1.0):
+    if isinstance(train_data_shuffler, tf.Tensor):
+        inputs = train_data_shuffler
+    else:
+        inputs = train_data_shuffler("data", from_queue=False)
+
+    initializer = tf.contrib.layers.xavier_initializer(seed=seed)
+
+    # first LSTM layer, which passes the outputs of all its time steps to the next layer
+    graph = lstm(inputs, lstm_cell_size, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm1', name='sync_cell_l1',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout, full_output=True)
+
+    # second LSTM layer with half the cell size; only its last output is kept
+    graph = lstm(graph, lstm_cell_size // 2, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm2', name='sync_cell_l2',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout)
+
+    regularizer = None
+
+    # fully connect the LSTM output to the classes
+    graph = slim.fully_connected(graph, num_classes, activation_fn=None, scope='fc1',
+                                 weights_initializer=initializer, weights_regularizer=regularizer, reuse=reuse)
+    return graph
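A usage sketch for the new two-layer network, assuming a plain tensor input and that the second layer consumes the stacked output of the first (the MNIST-style rows-as-time-steps shape below is illustrative):

    import tensorflow as tf

    # 10 images treated as sequences of 28 rows with 28 pixels each
    data = tf.placeholder(tf.float32, shape=(10, 28, 28))
    logits = double_lstm_network(data, lstm_cell_size=64, batch_size=10,
                                 num_time_steps=28, num_classes=10)
    # logits has shape (batch_size, num_classes), ready for a softmax cross-entropy loss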
+
+
+def triple_lstm_network(train_data_shuffler, lstm_cell_size=64, batch_size=10,
+                        num_time_steps=28, num_classes=10, seed=10, reuse=False,
+                        dropout=False, input_dropout=1.0, output_dropout=1.0):
+    if isinstance(train_data_shuffler, tf.Tensor):
+        inputs = train_data_shuffler
+    else:
+        inputs = train_data_shuffler("data", from_queue=False)
+
+    initializer = tf.contrib.layers.xavier_initializer(seed=seed)
+
+    # first LSTM layer, which passes the outputs of all its time steps to the next layer
+    graph = lstm(inputs, lstm_cell_size, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm1', name='sync_cell_l1',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout, full_output=True)
+
+    # second LSTM layer with half the cell size, again passing all time steps forward
+    graph = lstm(graph, lstm_cell_size // 2, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm2', name='sync_cell_l2',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout, full_output=True)
+
+    # third LSTM layer with a quarter of the cell size; only its last output is kept
+    graph = lstm(graph, lstm_cell_size // 4, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm3', name='sync_cell_l3',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout)
+
+    regularizer = None
+
+    # fully connect the LSTM output to the classes
+    graph = slim.fully_connected(graph, num_classes, activation_fn=None, scope='fc1',
+                                 weights_initializer=initializer, weights_regularizer=regularizer, reuse=reuse)
+    return graph
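triple_lstm_network follows the same pattern with a cell-size pyramid (64, 32, 16 for the default lstm_cell_size=64): the first two layers forward their full output sequences via full_output=True, and only the third collapses to its last time step before the fully connected classifier. It is called exactly like the two-layer variant, reusing the data placeholder from the previous sketch:

    logits = triple_lstm_network(data, lstm_cell_size=64, batch_size=10,
                                 num_time_steps=28, num_classes=10)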