Commit 01f076ee authored by Olivier Canévet

[test_lstm] Use slim fully connected

parent 7d70c020
Tags: v2.1.2
@@ -76,14 +76,18 @@ def test_dnn_trainer():
         graph = architecture(input_pl)
     elif version == "lstm":
-        W = tf.Variable(tf.random_normal([n_hidden, n_classes]))
-        b = tf.Variable(tf.random_normal([n_classes]))
+        slim = tf.contrib.slim
+        # W = tf.Variable(tf.random_normal([n_hidden, n_classes]))
+        # b = tf.Variable(tf.random_normal([n_classes]))
         graph = input_pl[:, n_input:]
         graph = tf.reshape(graph, (-1, n_steps, n_input))
         graph = tf.unstack(graph, n_steps, 1)
         lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
         outputs, states = tf.nn.static_rnn(lstm_cell, graph, dtype=tf.float32)
-        graph = tf.matmul(outputs[-1], W) + b
+        # graph = tf.matmul(outputs[-1], W) + b
+        graph = outputs[-1]
+        graph = slim.fully_connected(graph, n_classes, activation_fn=None)

     # Loss for the softmax
     loss = BaseLoss(tf.nn.sparse_softmax_cross_entropy_with_logits, tf.reduce_mean)
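
For reference, a minimal standalone sketch (TF 1.x; the n_hidden and n_classes values below are assumed for illustration, not taken from the test) of what this commit changes: slim.fully_connected computes the same affine map logits = h W + b as the manual matmul, but creates and tracks its own variables, using Xavier-initialized weights and zero biases by default instead of random_normal.

    import tensorflow as tf

    slim = tf.contrib.slim

    n_hidden, n_classes = 128, 10                     # assumed values
    h = tf.placeholder(tf.float32, [None, n_hidden])  # stands in for outputs[-1]

    # Before this commit: explicit variables plus matmul.
    W = tf.Variable(tf.random_normal([n_hidden, n_classes]))
    b = tf.Variable(tf.random_normal([n_classes]))
    logits_manual = tf.matmul(h, W) + b

    # After this commit: slim owns the variables; activation_fn=None keeps
    # the output as raw logits for tf.nn.sparse_softmax_cross_entropy_with_logits.
    logits_slim = slim.fully_connected(h, n_classes, activation_fn=None)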