Skip to content
Snippets Groups Projects
Commit e7d38e94 authored by Olivier Canévet's avatar Olivier Canévet
Browse files

[test_lstm] Add simple LSTM example on MNIST

parent f04d1ae5
No related branches found
No related tags found
No related merge requests found
...@@ -5,13 +5,106 @@ from bob.learn.tensorflow.datashuffler import Memory, ScaleFactor ...@@ -5,13 +5,106 @@ from bob.learn.tensorflow.datashuffler import Memory, ScaleFactor
from bob.learn.tensorflow.network import MLP, Embedding from bob.learn.tensorflow.network import MLP, Embedding
from bob.learn.tensorflow.loss import BaseLoss from bob.learn.tensorflow.loss import BaseLoss
from bob.learn.tensorflow.trainers import Trainer, constant from bob.learn.tensorflow.trainers import Trainer, constant
from bob.learn.tensorflow.utils import load_mnist from bob.learn.tensorflow.utils import load_real_mnist, load_mnist
from bob.learn.tensorflow.utils.session import Session
import tensorflow as tf import tensorflow as tf
import shutil
import bob.core import logging
logger = bob.core.log.setup("LSTM") # bob.learn.tensorflow does not work logger = logging.getLogger("bob.learn.tf")
bob.core.log.set_verbosity_level(logger, 3)
# Data ######################################################################
logger.debug("Loading MNIST")
train_data, train_labels, validation_data, validation_labels = load_mnist(data_dir="mnist") batch_size = 128
iterations = 200
seed = 10
learning_rate = 0.001
n_input = 28 # MNIST data input (img shape: 28*28)
n_steps = 27 # timesteps
n_hidden = 128 # hidden layer num of features
n_classes = 10 # MNIST total classes (0-9 digits)
directory = "./temp/lstm"
######################################################################
def test_network(embedding, test_data, test_labels):
# Testing
test_data_shuffler = Memory(test_data, test_labels,
input_shape=[None, 28*28],
batch_size=test_data.shape[0],
normalizer=ScaleFactor())
[data, labels] = test_data_shuffler.get_batch()
predictions = embedding(data)
logger.info("Test prediction size {}".format(predictions.shape))
acc = 100. * numpy.sum(numpy.argmax(predictions, axis=1) == labels) / predictions.shape[0]
# gt = tf.placeholder(tf.int64, [None, ])
# equal = tf.equal(tf.argmax(embedding.graph,1), gt)
# accuracy = tf.reduce_mean(tf.cast(equal, tf.float32))
# ss = Session.instance().session
# res = ss.run(embedding.graph, feed_dict={embedding.input: data})
# res2 = ss.run(accuracy, feed_dict={embedding.input: data, gt: labels})
# print("res {}".format(res.shape))
# print("acc2 {}".format(res2))
return acc
def test_dnn_trainer():
"""
"""
train_data, train_labels, test_data, test_labels = load_real_mnist(data_dir="mnist")
# Creating datashufflers
train_data_shuffler = Memory(train_data, train_labels,
input_shape=[None, 784],
batch_size=batch_size,
normalizer=ScaleFactor())
# Preparing the architecture
input_pl = train_data_shuffler("data", from_queue=False)
version = "lstm"
# Original code using MLP
if version == "mlp":
architecture = MLP(10, hidden_layers=[20, 40])
graph = architecture(input_pl)
elif version == "lstm":
W = tf.Variable(tf.random_normal([n_hidden, n_classes]))
b = tf.Variable(tf.random_normal([n_classes]))
graph = input_pl[:, n_input:]
graph = tf.reshape(graph, (-1, n_steps, n_input))
graph = tf.unstack(graph, n_steps, 1)
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
outputs, states = tf.nn.static_rnn(lstm_cell, graph, dtype=tf.float32)
graph = tf.matmul(outputs[-1], W) + b
# Loss for the softmax
loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
# One graph trainer
trainer = Trainer(train_data_shuffler,
iterations=iterations,
analizer=None,
temp_dir=directory)
trainer.create_network_from_scratch(graph=graph,
loss=loss,
learning_rate=constant(learning_rate, name="regular_lr"),
optimizer=tf.train.AdamOptimizer(learning_rate))
trainer.train()
# Test
embedding = Embedding(train_data_shuffler("data", from_queue=False), graph)
accuracy = test_network(embedding, test_data, test_labels)
logger.info("Accuracy {}".format(accuracy))
test_dnn_trainer()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment