Commit ccc3e3ce authored by Pavel KORSHUNOV

added dropout to lstm

parent 062b3719
@@ -13,7 +13,8 @@ logger = logging.getLogger("bob.learn")

 def lstm(inputs, lstm_cell_size, lstm_fn=tf.contrib.rnn.BasicLSTMCell, num_time_steps=20,
          output_activation_size=10, batch_size=10, scope='rnn',
-         weights_initializer=tf.random_normal, activation=tf.nn.relu, name=None, reuse=None):
+         weights_initializer=tf.random_normal, activation=tf.nn.relu,
+         name=None, reuse=None, dropout=False, input_dropout=1.0, output_dropout=1.0):
     """
     """
     return LSTM(lstm_cell_size=lstm_cell_size,
@@ -24,6 +25,9 @@ def lstm(inputs, lstm_cell_size, lstm_fn=tf.contrib.rnn.BasicLSTMCell, num_time_steps=20,
                 output_activation_size=output_activation_size,
                 weights_initializer=weights_initializer,
                 activation=activation,
+                dropout=dropout,
+                input_dropout=input_dropout,
+                output_dropout=output_dropout,
                 name=name,
                 reuse=reuse)(inputs)
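
With this change the lstm() convenience wrapper simply forwards the three new arguments to the LSTM layer. A minimal call-site sketch, assuming TF 1.x and that lstm() is imported from this module; the placeholder shape is illustrative, matching the MNIST-as-sequence test further below:

    import tensorflow as tf

    # batch of 10 sequences, 28 time steps of 28 features each
    inputs = tf.placeholder(tf.float32, shape=(10, 28, 28))
    net = lstm(inputs, lstm_cell_size=64, num_time_steps=28, batch_size=10,
               output_activation_size=10, scope='lstm',
               dropout=True, input_dropout=0.8, output_dropout=0.8)

With dropout=False (the default) the keep probabilities are ignored and the behaviour is unchanged.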
@@ -43,6 +47,9 @@ class LSTM(base.Layer):
                  activation=tf.nn.relu,
                  name=None,
                  reuse=None,
+                 dropout=False,
+                 input_dropout=1.0,
+                 output_dropout=1.0,
                  **kwargs):
         """
         :param lstm_cell_size [int]: size of the LSTM cell, i.e., the length of the output from each cell
@@ -54,6 +61,9 @@ class LSTM(base.Layer):
         self.lstm_cell_size = lstm_cell_size
         self.lstm = lstm_fn(self.lstm_cell_size, activation=activation, reuse=reuse, state_is_tuple=True, **kwargs)
+        if dropout:
+            self.lstm = tf.nn.rnn_cell.DropoutWrapper(self.lstm, input_keep_prob=input_dropout,
+                                                      output_keep_prob=output_dropout)
         self.batch_size = batch_size
         self.num_time_steps = num_time_steps
         self.scope = scope
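
For reference, tf.nn.rnn_cell.DropoutWrapper applies dropout independently to the inputs and outputs of the wrapped cell; a keep probability of 1.0 is a no-op, so the defaults above leave the cell untouched even when dropout=True. A self-contained sketch of the same wrapping outside this class, assuming TF 1.x:

    import tensorflow as tf

    cell = tf.contrib.rnn.BasicLSTMCell(64, state_is_tuple=True)
    # keep 80% of input and output activations during training
    cell = tf.nn.rnn_cell.DropoutWrapper(cell, input_keep_prob=0.8,
                                         output_keep_prob=0.8)
    inputs = tf.placeholder(tf.float32, shape=(10, 28, 28))  # [batch, time, features]
    outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

Note that DropoutWrapper also accepts state_keep_prob for dropout on the recurrent state, which this layer does not expose.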
@@ -15,13 +15,15 @@ import shutil
 import os
 import logging
 logger = logging.getLogger("bob.learn")

 slim = tf.contrib.slim


 def scratch_lstm_network(train_data_shuffler, lstm_cell_size=64, batch_size=10,
-                         num_time_steps=28, num_classes=10, seed=10, reuse=False):
+                         num_time_steps=28, num_classes=10, seed=10, reuse=False,
+                         dropout=False, input_dropout=1.0, output_dropout=1.0):
     inputs = train_data_shuffler("data", from_queue=False)

     initializer = tf.contrib.layers.xavier_initializer(seed=seed)
@@ -29,7 +31,8 @@ def scratch_lstm_network(train_data_shuffler, lstm_cell_size=64, batch_size=10,
     # Creating an LSTM network
     graph = lstm(inputs, lstm_cell_size, num_time_steps=num_time_steps, batch_size=batch_size,
                  output_activation_size=num_classes, scope='lstm',
-                 weights_initializer=initializer, activation=tf.nn.relu, reuse=reuse)
+                 weights_initializer=initializer, activation=tf.nn.relu, reuse=reuse,
+                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout)

     # fully connect the LSTM output to the classes
     graph = slim.fully_connected(graph, num_classes, activation_fn=None, scope='fc1',
@@ -38,14 +41,15 @@ def scratch_lstm_network(train_data_shuffler, lstm_cell_size=64, batch_size=10,
     return graph


-def validate_network(embedding, validation_data, validation_labels,
-                     input_shape=[None, 28, 28, 1], validation_batch_size=10,
-                     normalizer=ScaleFactor()):
+def validate_network(embedding, validation_data_shuffler):
+# def validate_network(embedding, validation_data, validation_labels,
+#                      input_shape=[None, 28, 28, 1], validation_batch_size=10,
+#                      normalizer=ScaleFactor()):
     # Testing
-    validation_data_shuffler = Memory(validation_data, validation_labels,
-                                      input_shape=input_shape,
-                                      batch_size=validation_batch_size,
-                                      normalizer=normalizer)
+    # validation_data_shuffler = Memory(validation_data, validation_labels,
+    #                                   input_shape=input_shape,
+    #                                   batch_size=validation_batch_size,
+    #                                   normalizer=normalizer)

     valid_range = 10
     accuracy = 0
@@ -107,7 +111,7 @@ def test_lstm_trainer():
     tf.reset_default_graph()

     train_data, train_labels, validation_data, validation_labels = load_mnist()
-    train_data = numpy.reshape(train_data, (train_data.shape[0], ) + tuple(input_shape[1:]))
+    train_data = numpy.reshape(train_data, (train_data.shape[0],) + tuple(input_shape[1:]))

     # Creating datashufflers
     train_data_shuffler = Memory(train_data, train_labels,
@@ -116,17 +120,34 @@
                                  data_augmentation=None,
                                  normalizer=ScaleFactor())

-    validation_data = numpy.reshape(validation_data, (validation_data.shape[0], ) + tuple(input_shape[1:]))
+    validation_data = numpy.reshape(validation_data, (validation_data.shape[0],) + tuple(input_shape[1:]))
     validation_data_shuffler = Memory(validation_data, validation_labels,
                                       input_shape=input_shape,
                                       batch_size=batch_size,
                                       data_augmentation=None,
                                       normalizer=ScaleFactor())

     # Create scratch network
     graph = scratch_lstm_network(train_data_shuffler,
                                  lstm_cell_size=lstm_cell_size,
                                  batch_size=batch_size,
                                  num_time_steps=num_time_steps,
                                  seed=seed,
-                                 num_classes=num_classes)
+                                 num_classes=num_classes,
+                                 dropout=True,
+                                 input_dropout=0.8,
+                                 output_dropout=0.8)

     # Setting the placeholders
-    embedding = Embedding(train_data_shuffler("data", from_queue=False), tf.nn.softmax(graph), normalizer=None)
+    # embedding = Embedding(train_data_shuffler("data", from_queue=False), graph, normalizer=None)
+    validation_graph = scratch_lstm_network(validation_data_shuffler,
+                                            lstm_cell_size=lstm_cell_size,
+                                            batch_size=batch_size,
+                                            num_time_steps=num_time_steps,
+                                            seed=seed,
+                                            num_classes=num_classes,
+                                            reuse=True)
+    embedding = Embedding(validation_data_shuffler("data", from_queue=False), validation_graph, normalizer=None)

     # Loss for the softmax
     loss = MeanSoftMaxLoss()
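
The validation graph is rebuilt on the validation shuffler with reuse=True, so it shares all variables with the training graph while skipping the DropoutWrapper entirely (dropout defaults to False), giving deterministic evaluation. A minimal sketch of this weight-sharing pattern, assuming TF 1.x; the names and shapes are illustrative:

    import tensorflow as tf

    def model(x, reuse=False):
        with tf.variable_scope('lstm', reuse=reuse):
            w = tf.get_variable('w', shape=(28, 10))
        return tf.matmul(x, w)

    train_x = tf.placeholder(tf.float32, shape=(10, 28))
    valid_x = tf.placeholder(tf.float32, shape=(10, 28))
    train_out = model(train_x)              # creates the variables
    valid_out = model(valid_x, reuse=True)  # reuses them, creates nothing new

Because the keep probabilities here are baked into the training graph as constants, turning dropout off at evaluation time requires a second graph like this rather than a feed-time switch.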
@@ -146,10 +167,11 @@ def test_lstm_trainer():
                       )
     trainer.train()

-    accuracy = validate_network(embedding, validation_data, validation_labels, validation_batch_size=batch_size)
+    # accuracy = validate_network(embedding, validation_data, validation_labels, validation_batch_size=batch_size)
+    accuracy = validate_network(embedding, validation_data_shuffler)
     logger.info("Ran for {0} full epochs".format(train_data_shuffler.epoch))
-    assert accuracy > 30
+    assert accuracy > 99
     shutil.rmtree(directory)

     del trainer
     tf.reset_default_graph()