Commit 2867742c authored by Pavel KORSHUNOV

LSTM layer with MNIST test case

parent e79c19e7
#!/usr/bin/env python
import tensorflow as tf
from tensorflow.python.layers import base
from tensorflow.python.framework import ops
def lstm(inputs, lstm_cell_size, lstm_fn=tf.contrib.rnn.BasicLSTMCell, num_time_steps=20,
output_activation_size=10, batch_size=10, scope='rnn',
weights_initializer=tf.random_normal, activation=tf.nn.relu, name=None, reuse=None):
"""
"""
return LSTM(lstm_cell_size=lstm_cell_size,
num_time_steps=num_time_steps,
batch_size=batch_size,
lstm_fn=lstm_fn,
scope=scope,
output_activation_size=output_activation_size,
weights_initializer=weights_initializer,
activation=activation,
name=name,
reuse=reuse)(inputs)
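# A minimal usage sketch (illustrative names; assumes MNIST-style inputs where each
# 28x28 image is read as 28 time steps of 28-dimensional row vectors):
#
#   images = tf.placeholder(tf.float32, shape=(32, 28, 28, 1))
#   last_output = lstm(images, lstm_cell_size=64, num_time_steps=28, batch_size=32,
#                      scope='lstm')
#
# The returned tensor is the output of the last LSTM cell, of shape
# (batch_size, lstm_cell_size).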
class LSTM(base.Layer):
"""
Basic LSTM layer in the format of tf-slim
"""
def __init__(self, lstm_cell_size,
num_time_steps=20,
batch_size=10,
lstm_fn=tf.contrib.rnn.BasicLSTMCell,
output_activation_size=10,
scope='rnn',
weights_initializer=tf.random_normal,
activation=tf.nn.relu,
name=None,
reuse=None,
**kwargs):
"""
:param lstm_cell_size [int]: size of the LSTM cell, i.e., the length of the output from each cell
:param batch_size [int]: input data batch size
:param num_time_steps [int]: the number of time steps of the input, i.e.,
the number of LSTM cells in one layer
"""
super(LSTM, self).__init__(name=name, trainable=False, **kwargs)
self.lstm_cell_size = lstm_cell_size
self.lstm = lstm_fn(self.lstm_cell_size, activation=activation, reuse=reuse, state_is_tuple=True, **kwargs)
self.batch_size = batch_size
self.num_time_steps = num_time_steps
self.scope = scope
hidden_state = tf.zeros([self.batch_size, self.lstm_cell_size])
current_state = tf.zeros([self.batch_size, self.lstm_cell_size])
self.states = hidden_state, current_state
# self.states = tf.zeros([self.batch_size, 2 * self.lstm_cell_size])
# self.states = None
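# NOTE: with state_is_tuple=True the cell expects a pair of state tensors (cell
# state and hidden state); the two zero tensors above provide that initial state,
# each of shape (batch_size, lstm_cell_size).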
self.sequence_length = None
self.output_activation_size = output_activation_size
# Define weights
self.output_activation_weights = {
'out': tf.Variable(weights_initializer([lstm_cell_size, self.output_activation_size]))
}
self.output_activation_biases = {
'out': tf.Variable(weights_initializer([self.output_activation_size]))
}
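# These variables define a linear projection from the last LSTM output to
# output_activation_size logits (weights of shape (lstm_cell_size, output_activation_size),
# biases of shape (output_activation_size,)); by default __call__ returns the raw
# LSTM output and leaves this projection commented out.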
def __call__(self, inputs):
"""
:param inputs: The input is expected to have shape (batch_size, num_time_steps, input_vector_size)
or (batch_size, num_time_steps, input_vector_size, 1).
"""
# shape inputs correctly
inputs = ops.convert_to_tensor(inputs)
shape = inputs.get_shape().as_list()
print("shape of the inputs: ", shape)
input_time_steps = shape[1] # second dimension must be the number of time steps in LSTM
if len(shape) == 4: # when inputs shape is 4, the last dimension must be 1
if shape[-1] == 1: # we accept last dimension to be 1, then we just reshape it
inputs = tf.reshape(inputs, shape=(-1, shape[1], shape[2]))
print("after reshape, shape of the inputs: ", inputs.get_shape().as_list())
else:
raise ValueError('The shape of input must be either (batch_size, num_time_steps, input_vector_size) or '
'(batch_size, num_time_steps, input_vector_size, 1), but it is {}'.format(shape))
if input_time_steps != self.num_time_steps:
raise ValueError('the number of time steps in one batch of input ({}) should be '
'the same as the num_time_steps of the LSTM ({})'
.format(input_time_steps, self.num_time_steps))
# convert inputs into a list of num_time_steps tensors, each of shape (batch_size, input_vector_size)
list_inputs = tf.unstack(inputs, self.num_time_steps, 1)
print("type of the input list: ", type(list_inputs))
print("Length of the converted list of inputs: ", len(list_inputs))
print("Size of each input: ", list_inputs[0].get_shape().as_list())
# list of length of each input sample
# self.sequence_length = shape[0]*[shape[1]]
# run LSTM training on the batch of inputs
# return the output (a list of self.num_time_steps outputs each of size input_vector_size)
# and remember the final states
# outputs, self.states = tf.nn.dynamic_rnn(self.lstm,
outputs, self.states = tf.contrib.rnn.static_rnn(self.lstm,
inputs=list_inputs,
initial_state=self.states,
# sequence_length=self.sequence_length,
dtype=tf.float32,
scope=self.scope)
# take the output of the last cell
# (the linear output projection below is kept commented out)
return outputs[-1]
# return tf.matmul(outputs[-1], self.output_activation_weights['out']) + self.output_activation_biases['out']
from .Layer import Layer
from .Conv1D import Conv1D
from .Maxout import MaxOut, maxout
from .LSTM import LSTM, lstm
# gets sphinx autodoc done right - don't remove it
__appropriate__(
Layer,
Conv1D,
maxout,
MaxOut,
LSTM,
lstm,
)
__all__ = [_ for _ in dir() if not _.startswith('_')]
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Thu 13 Oct 2016 13:35 CEST
import numpy
from bob.learn.tensorflow.datashuffler import Memory, ImageAugmentation, ScaleFactor, Linear, TFRecord
from bob.learn.tensorflow.network import Embedding
from bob.learn.tensorflow.loss import BaseLoss, MeanSoftMaxLoss
from bob.learn.tensorflow.trainers import Trainer, constant
from bob.learn.tensorflow.utils import load_mnist
from bob.learn.tensorflow.layers import lstm
import tensorflow as tf
import shutil
import os
import logging
logger = logging.getLogger("bob.learn")
"""
Some unit tests that create networks on the fly
"""
batch_size = 32
validation_batch_size = 32
iterations = 500
seed = 10
directory = "./temp/lstm_scratch"
slim = tf.contrib.slim
def scratch_lstm_network(train_data_shuffler, batch_size=16, num_classes=10, reuse=False):
inputs = train_data_shuffler("data", from_queue=False)
lstm_cell_size = 64
initializer = tf.contrib.layers.xavier_initializer(seed=seed)
# Creating an LSTM network
graph = lstm(inputs, lstm_cell_size, num_time_steps=28, batch_size=batch_size,
output_activation_size=lstm_cell_size, scope='lstm',
weights_initializer=initializer, activation=tf.nn.relu, reuse=reuse)
# fully connect the LSTM output to the classes
graph = slim.fully_connected(graph, num_classes, activation_fn=None, scope='fc1',
weights_initializer=initializer, reuse=reuse)
return graph
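# Shape walk-through for the MNIST case used in these tests: an input batch of
# shape (batch_size, 28, 28, 1) is fed to lstm(), which reads each image as 28 time
# steps of 28-dimensional rows and returns the last cell output of shape
# (batch_size, 64); slim.fully_connected then maps it to (batch_size, num_classes) logits.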
def validate_network(embedding, validation_data, validation_labels, input_shape=[None, 28, 28, 1],
normalizer=ScaleFactor()):
# Testing
validation_data_shuffler = Memory(validation_data, validation_labels,
input_shape=input_shape,
batch_size=validation_batch_size,
normalizer=normalizer)
[data, labels] = validation_data_shuffler.get_batch()
predictions = embedding(data)
accuracy = 100. * numpy.sum(numpy.argmax(predictions, axis=1) == labels) / predictions.shape[0]
logger.info("Validation accuracy = {0}".format(accuracy))
return accuracy
def test_lstm_trainer_scratch():
tf.reset_default_graph()
train_data, train_labels, validation_data, validation_labels = load_mnist()
train_data = numpy.reshape(train_data, (train_data.shape[0], 28, 28, 1))
# Creating datashufflers
data_augmentation = ImageAugmentation()
train_data_shuffler = Memory(train_data, train_labels,
input_shape=[None, 28, 28, 1],
batch_size=batch_size,
data_augmentation=data_augmentation,
normalizer=ScaleFactor())
validation_data = numpy.reshape(validation_data, (validation_data.shape[0], 28, 28, 1))
# Create scratch network
graph = scratch_lstm_network(train_data_shuffler, batch_size=validation_batch_size)
# Setting the placeholders
embedding = Embedding(train_data_shuffler("data", from_queue=False), graph)
# Loss for the softmax
loss = MeanSoftMaxLoss()
# One graph trainer
trainer = Trainer(train_data_shuffler,
iterations=iterations,
analizer=None,
temp_dir=directory)
learning_rate = constant(0.001, name="regular_lr")
trainer.create_network_from_scratch(graph=graph,
loss=loss,
learning_rate=learning_rate,
optimizer=tf.train.AdamOptimizer(learning_rate),
)
trainer.train()
accuracy = validate_network(embedding, validation_data, validation_labels)
assert accuracy > 95
shutil.rmtree(directory)
del trainer
tf.reset_default_graph()
assert len(tf.global_variables()) == 0
def test_lstm_trainer_scratch_tfrecord():
tf.reset_default_graph()
train_data, train_labels, validation_data, validation_labels = load_mnist()
train_data = train_data.astype("float32") * 0.00390625
validation_data = validation_data.astype("float32") * 0.00390625
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def create_tf_record(tfrecords_filename, data, labels):
writer = tf.python_io.TFRecordWriter(tfrecords_filename)
# for i in range(train_data.shape[0]):
for i in range(6000):
img = data[i]
img_raw = img.tostring()
feature = {'train/data': _bytes_feature(img_raw),
'train/label': _int64_feature(labels[i])
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example.SerializeToString())
writer.close()
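# For reference, a minimal sketch (illustrative only; the TFRecord datashuffler does
# the actual reading) of how such a record could be parsed back in TF 1.x:
#
#   reader = tf.TFRecordReader()
#   _, serialized = reader.read(filename_queue)
#   features = tf.parse_single_example(serialized, features={
#       'train/data': tf.FixedLenFeature([], tf.string),
#       'train/label': tf.FixedLenFeature([], tf.int64)})
#   image = tf.decode_raw(features['train/data'], tf.float32)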
tf.reset_default_graph()
# Creating the tf record
tfrecords_filename = "mnist_train.tfrecords"
create_tf_record(tfrecords_filename, train_data, train_labels)
filename_queue = tf.train.string_input_producer([tfrecords_filename], num_epochs=15, name="input")
tfrecords_filename_val = "mnist_validation.tfrecords"
create_tf_record(tfrecords_filename_val, validation_data, validation_labels)
filename_queue_val = tf.train.string_input_producer([tfrecords_filename_val], num_epochs=15,
name="input_validation")
# Creating the CNN using the TFRecord as input
train_data_shuffler = TFRecord(filename_queue=filename_queue,
batch_size=batch_size)
validation_data_shuffler = TFRecord(filename_queue=filename_queue_val,
batch_size=validation_batch_size)
graph = scratch_lstm_network(train_data_shuffler, batch_size=validation_batch_size)
validation_graph = scratch_lstm_network(validation_data_shuffler, batch_size=validation_batch_size, reuse=True)
# Setting the placeholders
# Loss for the softmax
loss = MeanSoftMaxLoss()
# One graph trainer
trainer = Trainer(train_data_shuffler,
validation_data_shuffler=validation_data_shuffler,
iterations=iterations, # It is super fast
analizer=None,
temp_dir=directory)
learning_rate = constant(0.001, name="regular_lr")
trainer.create_network_from_scratch(graph=graph,
validation_graph=validation_graph,
loss=loss,
learning_rate=learning_rate,
optimizer=tf.train.AdamOptimizer(learning_rate),
)
trainer.train()
os.remove(tfrecords_filename)
os.remove(tfrecords_filename_val)
assert True
tf.reset_default_graph()
del trainer
assert len(tf.global_variables()) == 0