Commit bce79d8d authored by Olivier Canévet

[layers] Rename lstm to rnn to be more general

parent f4cb9451
@@ -4,32 +4,43 @@ import tensorflow as tf
 from tensorflow.python.layers import base
 
-def lstm(inputs, n_hidden, name=None):
-    """
-    """
-    return LSTM(n_hidden=n_hidden, name=name)(inputs)
+# def lstm(inputs, n_hidden, name=None):
+#     """
+#     """
+#     return LSTM(n_hidden=n_hidden, name=name)(inputs)
+
+def rnn(inputs, n_hidden, cell_fn, cell_args, name=None):
+    """
+    Builds an RNN layer from the given cell and returns its last output.
+    """
+    return RNN(n_hidden=n_hidden,
+               cell_fn=cell_fn,
+               cell_args=cell_args,
+               name=name)(inputs)
 
-class LSTM(base.Layer):
+class RNN(base.Layer):
     """
+    Inspired by tensorlayer/tensorlayer/layers.py
     """
-    def __init__(self,
-                 n_hidden,
+    def __init__(self, n_hidden,
+                 cell_fn=tf.nn.rnn_cell.BasicLSTMCell,
+                 cell_args={"forget_bias": 1.0},
                  name=None,
                  **kwargs):
-        super(LSTM, self).__init__(name=name,
-                                   **kwargs)
+        """
+        Creates the recurrent cell from cell_fn and cell_args.
+        """
+        super(RNN, self).__init__(name=name, **kwargs)
         self.n_hidden = n_hidden
-        self.lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(self.n_hidden,
-                                                      forget_bias=1.0)
+        self.cell = cell_fn(num_units=self.n_hidden, **cell_args)
 
     def call(self, inputs, training=False):
         """
         """
-        outputs, states = tf.nn.static_rnn(self.lstm_cell,
+        outputs, states = tf.nn.static_rnn(self.cell,
                                            inputs,
                                            dtype=tf.float32)
+        # Compared to tensorlayer, it is as if return_last = True
         return outputs[-1]
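For context, a minimal usage sketch of the renamed layer (not part of the commit): the cell_fn / cell_args pair is what makes the wrapper general, since the same call can build a BasicLSTMCell or any other tf.nn.rnn_cell cell. The placeholder shape, the hidden size and the GRUCell swap below are illustrative assumptions; tf.nn.static_rnn consumes a list of [batch, features] tensors, hence the tf.unstack. The rnn name refers to the wrapper defined in the hunk above.

import tensorflow as tf

# Hypothetical input pipeline: a [batch, n_steps, n_input] placeholder split
# into the list of per-step tensors that tf.nn.static_rnn expects.
x = tf.placeholder(tf.float32, [None, 28, 28])
inputs = tf.unstack(x, num=28, axis=1)

# Same behaviour as the old lstm() helper: a BasicLSTMCell with forget_bias=1.0.
last_lstm = rnn(inputs, n_hidden=128,
                cell_fn=tf.nn.rnn_cell.BasicLSTMCell,
                cell_args={"forget_bias": 1.0},
                name="lstm_layer")

# The same wrapper with a GRU cell, which is what the rename enables.
last_gru = rnn(inputs, n_hidden=128,
               cell_fn=tf.nn.rnn_cell.GRUCell,
               cell_args={},
               name="gru_layer")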
 from .Layer import Layer
 from .Conv1D import Conv1D
 #from .Maxout import maxout
-from .LSTM import lstm
+from .RNN import rnn
 
 # gets sphinx autodoc done right - don't remove it
@@ -21,6 +21,6 @@ def __appropriate__(*args):
 __appropriate__(
     Layer,
     Conv1D,
-    LSTM
+    RNN
     )
 __all__ = [_ for _ in dir() if not _.startswith('_')]
@@ -119,5 +119,6 @@ def test_dnn_trainer():
     embedding = Embedding(train_data_shuffler("data", from_queue=False), graph)
     accuracy = test_network(embedding, test_data, test_labels)
     logger.info("Accuracy {}".format(accuracy))
+    assert accuracy > 0.8
 
 test_dnn_trainer()