Commit f4cb9451 authored by Olivier Canévet

[layers] Add LSTM

parent 01f076ee
New file bob/learn/tensorflow/layers/LSTM.py:

#!/usr/bin/env python
import tensorflow as tf
from tensorflow.python.layers import base


def lstm(inputs, n_hidden, name=None):
    """Functional interface to the LSTM layer: applies an LSTM with
    `n_hidden` units to `inputs` and returns the output of the last
    time step.
    """
    return LSTM(n_hidden=n_hidden, name=name)(inputs)


class LSTM(base.Layer):
    """Layer wrapping a `tf.nn.rnn_cell.BasicLSTMCell` unrolled with
    `tf.nn.static_rnn`.
    """

    def __init__(self, n_hidden, name=None, **kwargs):
        super(LSTM, self).__init__(name=name, **kwargs)
        self.n_hidden = n_hidden
        self.lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(self.n_hidden,
                                                      forget_bias=1.0)

    def call(self, inputs, training=False):
        """Unrolls the cell over `inputs`, a list of `(batch, n_input)`
        tensors (one per time step), and returns the last output.
        """
        outputs, states = tf.nn.static_rnn(self.lstm_cell,
                                           inputs,
                                           dtype=tf.float32)
        return outputs[-1]
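
A hedged usage sketch, not part of the commit: tf.nn.static_rnn consumes a Python list of 2-D tensors, one per time step, so the layer is called on such a list (the step count, feature size and layer name below are illustrative assumptions):

import tensorflow as tf

# One (batch, n_input) placeholder per time step; 28 steps of 28
# features each are assumed for illustration, not taken from the commit.
steps = [tf.placeholder(tf.float32, (None, 28)) for _ in range(28)]
last_output = lstm(steps, n_hidden=32, name="lstm1")   # shape (batch, 32)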
bob/learn/tensorflow/layers/__init__.py:

 from .Layer import Layer
 from .Conv1D import Conv1D
 #from .Maxout import maxout
+from .LSTM import lstm, LSTM

 # gets sphinx autodoc done right - don't remove it
@@ -19,7 +20,7 @@ def __appropriate__(*args):
 __appropriate__(
     Layer,
-    Conv1D
+    Conv1D,
+    LSTM
     )
 __all__ = [_ for _ in dir() if not _.startswith('_')]
Changes to the test defining test_dnn_trainer():

@@ -7,6 +7,8 @@ from bob.learn.tensorflow.loss import BaseLoss
 from bob.learn.tensorflow.trainers import Trainer, constant
 from bob.learn.tensorflow.utils import load_real_mnist, load_mnist
 from bob.learn.tensorflow.utils.session import Session
+from bob.learn.tensorflow.layers import lstm
+import tensorflow as tf

 import shutil
@@ -68,7 +70,7 @@ def test_dnn_trainer():
     # Preparing the architecture
     input_pl = train_data_shuffler("data", from_queue=False)

-    version = "lstm"
+    version = "bob"

     # Original code using MLP
     if version == "mlp":
@@ -88,6 +90,13 @@ def test_dnn_trainer():
         graph = outputs[-1]
         graph = slim.fully_connected(graph, n_classes, activation_fn=None)

+    elif version == "bob":
+        slim = tf.contrib.slim
+        graph = input_pl[:, n_input:]
+        graph = tf.reshape(graph, (-1, n_steps, n_input))
+        graph = tf.unstack(graph, n_steps, 1)
+        graph = lstm(graph, n_hidden)
+        graph = slim.fully_connected(graph, n_classes, activation_fn=None)

     # Loss for the softmax
     loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
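
A minimal sketch (not from the commit) of the shape flow in the new "bob" branch, with illustrative values for n_steps, n_input, n_hidden and n_classes; the slice on input_pl is omitted here since its offset depends on the data shuffler's layout:

import tensorflow as tf

n_steps, n_input, n_hidden, n_classes = 28, 28, 32, 10   # assumed values
flat = tf.placeholder(tf.float32, (None, n_steps * n_input))
seq = tf.reshape(flat, (-1, n_steps, n_input))   # (batch, n_steps, n_input)
seq = tf.unstack(seq, n_steps, 1)                # list of n_steps (batch, n_input) tensors
last = lstm(seq, n_hidden)                       # (batch, n_hidden): last time step only
logits = tf.contrib.slim.fully_connected(last, n_classes, activation_fn=None)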