Commit 05868b93 authored by Pavel KORSHUNOV's avatar Pavel KORSHUNOV

support new tensorflow package

parent c6d935fc
Pipeline #12614 passed with stages
in 11 minutes and 17 seconds
......@@ -10,6 +10,7 @@ import bob.io.base
# import tensorflow as tf
import os
import logging
logger = logging.getLogger("bob.pad.voice")
......@@ -30,24 +31,59 @@ class TensorflowAlgo(Algorithm):
)
self.data_reader = None
# self.session = tf.Session()
self.session = None
self.dnn_model = None
self.data_placeholder = None
# def __del__(self):
# self.session.close()
def simple_lstm_network(self, train_data_shuffler, batch_size=10, lstm_cell_size=64,
num_time_steps=28, num_classes=10, seed=10, reuse=False):
import tensorflow as tf
from bob.learn.tensorflow.layers import lstm
slim = tf.contrib.slim
if isinstance(train_data_shuffler, tf.Tensor):
inputs = train_data_shuffler
else:
inputs = train_data_shuffler("data", from_queue=False)
initializer = tf.contrib.layers.xavier_initializer(seed=seed)
# Creating an LSTM network
graph = lstm(inputs, lstm_cell_size, num_time_steps=num_time_steps, batch_size=batch_size,
output_activation_size=lstm_cell_size, scope='lstm',
weights_initializer=initializer, activation=tf.nn.relu, reuse=reuse)
# fully connect the LSTM output to the classes
graph = slim.fully_connected(graph, num_classes, activation_fn=None, scope='fc1',
weights_initializer=initializer, reuse=reuse)
return graph
def _check_feature(self, feature):
"""Checks that the features are appropriate."""
if not isinstance(feature, numpy.ndarray) or feature.ndim != 1 or feature.dtype != numpy.float32:
raise ValueError("The given feature is not appropriate", feature)
return True
def restore_trained_model(self, projector_file):
import tensorflow as tf
if self.session is None:
self.session = tf.Session()
data_pl = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
graph = self.simple_lstm_network(data_pl, reuse=False)
self.session.run(tf.global_variables_initializer())
saver = tf.train.import_meta_graph(projector_file + ".meta", clear_devices=True)
saver.restore(self.session, projector_file)
return graph, data_pl
def load_projector(self, projector_file):
logger.info("Loading pretrained model from {0}".format(projector_file))
from bob.learn.tensorflow.network.SequenceNetwork import SequenceNetwork
self.dnn_model = SequenceNetwork()
# self.dnn_model.load_hdf5(bob.io.base.HDF5File(projector_file), shape=[1, 6560, 1])
self.dnn_model.load(projector_file, True)
self.dnn_model, self.data_placeholder = self.restore_trained_model(projector_file)
def project_feature(self, feature):
......@@ -55,12 +91,16 @@ class TensorflowAlgo(Algorithm):
from bob.learn.tensorflow.datashuffler import DiskAudio
if not self.data_reader:
self.data_reader = DiskAudio([0], [0])
frames, labels = self.data_reader.extract_frames_from_wav(feature, 0)
# frames, labels = self.data_reader.extract_frames_from_wav(feature, 0)
frames, labels = self.data_reader.split_features_in_windows(feature, 0, )
frames = numpy.asarray(frames)
logger.debug(" .... And %d frames are extracted to pass into DNN model" % frames.shape[0])
frames = numpy.reshape(frames, (frames.shape[0], -1, 1))
forward_output = self.dnn_model(frames)
# return tf.nn.log_softmax(tf.nn.log_softmax(forward_output)).eval(session=self.session)
if self.session is not None:
forward_output = self.session.run(self.dnn_model, feed_dict={self.data_placeholder: frames})
else:
raise ValueError("Tensorflow session was not initialized, so cannot project on DNN model!")
return forward_output
def project(self, feature):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment