Commit 289c1030 authored by Pavel KORSHUNOV

added support for 2, 3 layered lstm

parent 45082083
@@ -14,10 +14,10 @@ logger = logging.getLogger("bob.learn")
 def lstm(inputs, lstm_cell_size, lstm_fn=tf.contrib.rnn.BasicLSTMCell, num_time_steps=20,
          output_activation_size=10, batch_size=10, scope='rnn',
          weights_initializer=tf.random_normal, activation=tf.nn.relu,
-         name=None, reuse=None, dropout=False, input_dropout=1.0, output_dropout=1.0):
+         name=None, reuse=None, dropout=False, input_dropout=1.0, output_dropout=1.0, full_output=False):
     """
     """
-    return LSTM(lstm_cell_size=lstm_cell_size,
+    output = LSTM(lstm_cell_size=lstm_cell_size,
                 num_time_steps=num_time_steps,
                 batch_size=batch_size,
                 lstm_fn=lstm_fn,
@@ -30,6 +30,11 @@ def lstm(inputs, lstm_cell_size, lstm_fn=tf.contrib.rnn.BasicLSTMCell, num_time_
                 output_dropout=output_dropout,
                 name=name,
                 reuse=reuse)(inputs)
+    if full_output:
+        logger.info("LSTM: the number of the outputs: {0}".format(len(output)))
+        return output
+    logger.info("LSTM: the shape of the output: {0}".format(output[-1].shape))
+    return output[-1]


 class LSTM(base.Layer):
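
For reference, this is roughly how the new full_output flag changes what the lstm() wrapper returns; a minimal sketch, where the import path, shapes and scope names are illustrative assumptions rather than part of this commit:

import tensorflow as tf
from bob.learn.tensorflow.layers import lstm  # assumed import path for the wrapper shown above

# a batch of 10 sequences, each with 20 time steps of 40-dimensional features
inputs = tf.placeholder(tf.float32, shape=(10, 20, 40))

# default behaviour: only the output of the last time step is returned
last_output = lstm(inputs, 64, num_time_steps=20, batch_size=10, scope='lstm_last')

# full_output=True returns the whole list of num_time_steps per-step outputs,
# which is what the stacked networks further down feed into the next LSTM layer
all_outputs = lstm(inputs, 64, num_time_steps=20, batch_size=10, scope='lstm_full',
                   full_output=True)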
@@ -95,23 +100,30 @@ class LSTM(base.Layer):
         shape = inputs.get_shape().as_list()
         logger.info("LSTM: the shape of the inputs: {0}".format(shape))
-        input_time_steps = shape[1]  # second dimension must be the number of time steps in LSTM
-        if len(shape) == 4:  # when inputs shape is 4, the last dimension must be 1
-            if shape[-1] == 1:  # we accept last dimension to be 1, then we just reshape it
-                inputs = tf.reshape(inputs, shape=(-1, shape[1], shape[2]))
-                logger.info("LSTM: after reshape, the shape of the inputs: {0}".format(inputs.get_shape().as_list()))
-            else:
-                raise ValueError('The shape of input must be either (batch_size, num_time_steps, input_vector_size) or '
-                                 '(batch_size, num_time_steps, input_vector_size, 1), but it is {}'.format(shape))
-        if input_time_steps % self.num_time_steps:
-            raise ValueError('number of rows in one batch of input ({}) should be '
-                             'the same as the num_time_steps of LSTM ({})'
-                             .format(input_time_steps, self.num_time_steps))
-        # convert inputs into the num_time_steps list of the inputs each of shape (batch_size, input_vector_size)
-        list_inputs = tf.unstack(inputs, self.num_time_steps, 1)
+        # if the input is already formatted correctly, just use it as is
+        if shape[1] == self.batch_size and shape[0] == self.num_time_steps:
+            inputs.set_shape((shape[0], None, shape[2]))
+            logger.info("LSTM: undefined batch shape inputs: {0}".format(inputs.get_shape().as_list()))
+            list_inputs = tf.unstack(inputs, self.num_time_steps, 0)
+        # here we consider all special cases
+        else:
+            input_time_steps = shape[1]  # second dimension must be the number of time steps in LSTM
+            if len(shape) == 4:  # when inputs shape is 4, the last dimension must be 1
+                if shape[-1] == 1:  # we accept last dimension to be 1, then we just reshape it
+                    inputs = tf.reshape(inputs, shape=(-1, shape[1], shape[2]))
+                    logger.info("LSTM: after reshape, the shape of the inputs: {0}".format(inputs.get_shape().as_list()))
+                else:
+                    raise ValueError('The shape of input must be either (batch_size, num_time_steps, input_vector_size) or '
+                                     '(batch_size, num_time_steps, input_vector_size, 1), but it is {}'.format(shape))
+            if input_time_steps % self.num_time_steps:
+                raise ValueError('number of rows in one batch of input ({}) should be '
+                                 'the same as the num_time_steps of LSTM ({})'
+                                 .format(input_time_steps, self.num_time_steps))
+            # convert inputs into the num_time_steps list of the inputs each of shape (batch_size, input_vector_size)
+            list_inputs = tf.unstack(inputs, self.num_time_steps, 1)

         # run LSTM training on the batch of inputs
         # return the output (a list of self.num_time_steps outputs each of size input_vector_size)
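
The shape handling above amounts to the following; a standalone sketch with arbitrary example sizes:

import tensorflow as tf

batch_size, num_time_steps, input_vector_size = 10, 20, 40

# a 4D input with a trailing singleton dimension is first reshaped down to 3D
x4 = tf.placeholder(tf.float32, (batch_size, num_time_steps, input_vector_size, 1))
x3 = tf.reshape(x4, shape=(-1, num_time_steps, input_vector_size))

# unstacking along axis 1 yields num_time_steps tensors of shape
# (batch_size, input_vector_size), i.e. one input per LSTM time step
list_inputs = tf.unstack(x3, num_time_steps, 1)
assert len(list_inputs) == num_time_steps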
@@ -123,6 +135,6 @@ class LSTM(base.Layer):
                                       scope=self.scope)
         # consider the output of the last cell
-        return outputs[-1]
+        return outputs
         # return tf.matmul(outputs[-1], self.output_activation_weights['out']) + self.output_activation_biases['out']
@@ -3,7 +3,7 @@ from .LightCNN9 import LightCNN9
 from .Dummy import Dummy
 from .MLP import MLP
 from .Embedding import Embedding
-from .lstm import simple_lstm_network
+from .lstm import simple_lstm_network, double_lstm_network, triple_lstm_network
 from .lstm import RegularizedLoss
 from .simplemlp import mlp_network
 from .simplecnn import simple2Dcnn_network
@@ -69,3 +69,71 @@ def simple_lstm_network(train_data_shuffler, lstm_cell_size=64, batch_size=10,
                                  weights_initializer=initializer, weights_regularizer=regularizer, reuse=reuse)
     return graph
+
+
+def double_lstm_network(train_data_shuffler, lstm_cell_size=64, batch_size=10,
+                        num_time_steps=28, num_classes=10, seed=10, reuse=False,
+                        dropout=False, input_dropout=1.0, output_dropout=1.0):
+    if isinstance(train_data_shuffler, tf.Tensor):
+        inputs = train_data_shuffler
+    else:
+        inputs = train_data_shuffler("data", from_queue=False)
+
+    initializer = tf.contrib.layers.xavier_initializer(seed=seed)
+
+    # First LSTM layer, returning the outputs of all time steps
+    graph = lstm(inputs, lstm_cell_size, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm1', name='sync_cell_l1',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout, full_output=True)
+
+    # Second LSTM layer with half the cell size
+    graph = lstm(graph, lstm_cell_size / 2, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm2', name='sync_cell_l2',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout)
+
+    regularizer = None
+
+    # fully connect the output of the last LSTM layer to the classes
+    graph = slim.fully_connected(graph, num_classes, activation_fn=None, scope='fc1',
+                                 weights_initializer=initializer, weights_regularizer=regularizer, reuse=reuse)
+
+    return graph
+
+
+def triple_lstm_network(train_data_shuffler, lstm_cell_size=64, batch_size=10,
+                        num_time_steps=28, num_classes=10, seed=10, reuse=False,
+                        dropout=False, input_dropout=1.0, output_dropout=1.0):
+    if isinstance(train_data_shuffler, tf.Tensor):
+        inputs = train_data_shuffler
+    else:
+        inputs = train_data_shuffler("data", from_queue=False)
+
+    initializer = tf.contrib.layers.xavier_initializer(seed=seed)
+
+    # First LSTM layer, returning the outputs of all time steps
+    graph = lstm(inputs, lstm_cell_size, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm1', name='sync_cell_l1',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout, full_output=True)
+
+    # Second LSTM layer with half the cell size, also returning all time steps
+    graph = lstm(graph, lstm_cell_size / 2, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm2', name='sync_cell_l2',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout, full_output=True)
+
+    # Third LSTM layer with a quarter of the cell size
+    graph = lstm(graph, lstm_cell_size / 4, num_time_steps=num_time_steps, batch_size=batch_size,
+                 output_activation_size=num_classes, scope='lstm3', name='sync_cell_l3',
+                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout)
+
+    regularizer = None
+
+    # fully connect the output of the last LSTM layer to the classes
+    graph = slim.fully_connected(graph, num_classes, activation_fn=None, scope='fc1',
+                                 weights_initializer=initializer, weights_regularizer=regularizer, reuse=reuse)
+
+    return graph
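
A minimal usage sketch for the new builders; the module path and the placeholder shape are assumptions for illustration only:

import tensorflow as tf
from bob.learn.tensorflow.network import double_lstm_network  # assumed module path

# e.g. 28x28 inputs treated as 28 time steps of 28 features, in batches of 10
data = tf.placeholder(tf.float32, shape=(10, 28, 28))

# two stacked LSTM layers (64 and 32 cells) followed by a fully connected layer
logits = double_lstm_network(data, lstm_cell_size=64, batch_size=10,
                             num_time_steps=28, num_classes=10)

triple_lstm_network takes the same arguments and adds a third LSTM layer with a quarter of the original cell size.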