Skip to content
Snippets Groups Projects
Commit e3647d17 authored by Olivier Canévet's avatar Olivier Canévet
Browse files

[test_lstm] Fix time series generation

parent 19d7acb5
Branches
No related tags found
No related merge requests found
#!/usr/bin/env python
import sys
import numpy
import numpy as np
import random
from bob.learn.tensorflow.datashuffler import Memory, ScaleFactor
......@@ -45,7 +45,7 @@ def test_network(embedding, test_data, test_labels):
[data, labels] = test_data_shuffler.get_batch()
predictions = embedding(data)
logger.info("Test prediction size {}".format(predictions.shape))
acc = 100. * numpy.sum(numpy.argmax(predictions, axis=1) == labels) / predictions.shape[0]
acc = 100. * np.sum(np.argmax(predictions, axis=1) == labels) / predictions.shape[0]
# gt = tf.placeholder(tf.int64, [None, ])
# equal = tf.equal(tf.argmax(embedding.graph,1), gt)
......@@ -157,35 +157,37 @@ def generate_data_at_time(t, n_steps, dt):
return t, x, sequence_before, target
def generate_training_data(n_train, dt):
def generate_training_data(n_train, n_steps, dt):
"""
"""
t0, x0, s0, t0 = generate_data_at_time(0)
n_steps = s0.shape[0]
t0, x0, s0, t0 = generate_data_at_time(0, n_steps, dt)
dim = s0.shape[1]
times = np.zeros((n_train,1))
sequences = np.zeros((n_train, n_steps, dim))
targets = np.zeros((n_train, dim))
for n in range(batch_size):
for n in range(n_train):
t = 4*np.pi*(random.random() - 0.5)
t0, x0, ty0, sy0 = generate_data_at_time(t, n_steps, dt)
t[n,0] = t0
ty[n] = ty0
sy[n] = sy0
return t, ty, sy
t0, x0, s0, y0 = generate_data_at_time(t, n_steps, dt)
times[n,0] = t0
sequences[n] = s0
targets[n] = y0
return times, sequences, targets
def test_lstm_trainer_on_real_functions():
"""
"""
dt = 0.01
n_train = 7
n_steps = 3
# Generate train data in 3D matrix to use Memory class
train_data, train_targets = generate_training_data(n_train, n_steps)
times, train_data, train_targets = generate_training_data(n_train, n_steps, dt)
print(train_data)
# # Creating datashufflers
# train_data_shuffler = Memory(train_data, train_targets,
......@@ -193,6 +195,6 @@ def test_lstm_trainer_on_real_functions():
# batch_size=batch_size,
# normalizer=ScaleFactor())
test_lstm_trainer_on_mnist()
# test_lstm_trainer_on_mnist()
# test_lstm_trainer_on_real_functions()
test_lstm_trainer_on_real_functions()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment