Commit 8e81ae89 authored by Olivier Canévet

[test_lstm] Use rnn from bob correctly

parent 8e673720
 #!/usr/bin/env python
 import tensorflow as tf
 from tensorflow.python.layers import base

 # def lstm(inputs, n_hidden, name=None):
 #     """
 #     """
 #     return LSTM(n_hidden=n_hidden, name=name)(inputs)

-def rnn(inputs, n_hidden, cell_fn, cell_args, name=None):
+def rnn(inputs, n_hidden,
+        cell_fn = tf.nn.rnn_cell.BasicLSTMCell,
+        cell_args = { "forget_bias": 1.0, },
+        name = None):
     """
     """
-    return RNN(n_hidden=n_hidden,
+    return RNN(n_hidden = n_hidden,
                cell_fn = cell_fn,
                cell_args = cell_args,
                name=name)(inputs)
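For reference, the new defaults make the recurrent cell pluggable through cell_fn and cell_args. The following is a minimal, hypothetical usage sketch, not part of this commit: it prepares the inputs the same way the test further down does, calls rnn with its new defaults, and then overrides cell_fn with a GRU cell. The GRU variant, the variable names, and the tensor sizes are illustrative assumptions.

import tensorflow as tf
from bob.learn.tensorflow.layers import rnn

n_steps, n_input, n_hidden = 28, 28, 32  # illustrative sizes

# Flattened input sequences, reshaped to (batch, n_steps, n_input) and
# unstacked into a list of n_steps tensors, as in the test below.
x = tf.placeholder(tf.float32, shape=(None, n_steps * n_input))
seq = tf.unstack(tf.reshape(x, (-1, n_steps, n_input)), n_steps, 1)

# Uses the new defaults: BasicLSTMCell with forget_bias=1.0.
lstm_out = rnn(seq, n_hidden, name="lstm")

# Hypothetical override: a GRU cell supplied through cell_fn/cell_args.
gru_out = rnn(seq, n_hidden,
              cell_fn=tf.nn.rnn_cell.GRUCell,
              cell_args={},
              name="gru")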
@@ -7,7 +7,7 @@ from bob.learn.tensorflow.loss import BaseLoss
 from bob.learn.tensorflow.trainers import Trainer, constant
 from bob.learn.tensorflow.utils import load_real_mnist, load_mnist
 from bob.learn.tensorflow.utils.session import Session
-from bob.learn.tensorflow.layers import lstm
+from bob.learn.tensorflow.layers import rnn
 import tensorflow as tf
 import shutil
@@ -95,7 +95,7 @@ def test_dnn_trainer():
     graph = input_pl[:, n_input:]
     graph = tf.reshape(graph, (-1, n_steps, n_input))
     graph = tf.unstack(graph, n_steps, 1)
-    graph = lstm(graph, n_hidden)
+    graph = rnn(graph, n_hidden)
     graph = slim.fully_connected(graph, n_classes, activation_fn=None)
     # Loss for the softmax
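The hunk is truncated right after the comment announcing the softmax loss. As a rough, hypothetical continuation in plain TensorFlow (the test itself relies on BaseLoss from bob.learn.tensorflow.loss, imported above, whose exact call is not shown in this diff), a softmax loss over the fully connected logits could be sketched as:

# Hypothetical stand-in for the BaseLoss-based softmax loss used by the test;
# "graph" is the output of slim.fully_connected above, shape (batch, n_classes).
labels_pl = tf.placeholder(tf.int64, shape=(None,))
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels_pl,
                                                   logits=graph))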