Commit 8afa34ab authored by Pavel KORSHUNOV's avatar Pavel KORSHUNOV

working version of tf classifier

parent 05868b93
Pipeline #12682 passed with stages
in 13 minutes and 13 seconds
......@@ -6,11 +6,9 @@
from bob.pad.base.algorithm import Algorithm
import numpy
import bob.io.base
# import tensorflow as tf
import os
import logging
logger = logging.getLogger("bob.pad.voice")
......@@ -53,18 +51,18 @@ class TensorflowAlgo(Algorithm):
# Creating an LSTM network
graph = lstm(inputs, lstm_cell_size, num_time_steps=num_time_steps, batch_size=batch_size,
output_activation_size=lstm_cell_size, scope='lstm',
weights_initializer=initializer, activation=tf.nn.relu, reuse=reuse)
output_activation_size=num_classes, scope='lstm',
weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse)
# fully connect the LSTM output to the classes
graph = slim.fully_connected(graph, num_classes, activation_fn=None, scope='fc1',
weights_initializer=initializer, reuse=reuse)
# graph = slim.fully_connected(graph, num_classes, activation_fn=None, scope='fc1',
# weights_initializer=initializer, reuse=reuse)
return graph
def _check_feature(self, feature):
"""Checks that the features are appropriate."""
if not isinstance(feature, numpy.ndarray) or feature.ndim != 1 or feature.dtype != numpy.float32:
if not isinstance(feature, numpy.ndarray) or feature.ndim != 2 or feature.dtype != numpy.float32:
raise ValueError("The given feature is not appropriate", feature)
return True
......@@ -72,13 +70,16 @@ class TensorflowAlgo(Algorithm):
import tensorflow as tf
if self.session is None:
self.session = tf.Session()
data_pl = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
graph = self.simple_lstm_network(data_pl, reuse=False)
data_pl = tf.placeholder(tf.float32, shape=(None, 60, 102))
graph = self.simple_lstm_network(data_pl, batch_size=1,
lstm_cell_size=60, num_time_steps=60,
num_classes=2, reuse=False)
self.session.run(tf.global_variables_initializer())
saver = tf.train.import_meta_graph(projector_file + ".meta", clear_devices=True)
saver = tf.train.Saver()
# saver = tf.train.import_meta_graph(projector_file + ".meta", clear_devices=True)
saver.restore(self.session, projector_file)
return graph, data_pl
return tf.nn.softmax(graph, name="softmax"), data_pl
def load_projector(self, projector_file):
logger.info("Loading pretrained model from {0}".format(projector_file))
......@@ -87,21 +88,27 @@ class TensorflowAlgo(Algorithm):
def project_feature(self, feature):
logger.debug(" .... Projecting %d features vector" % feature.shape[0])
logger.info(" .... Projecting %d features vector" % feature.shape[0])
from bob.learn.tensorflow.datashuffler import DiskAudio
if not self.data_reader:
self.data_reader = DiskAudio([0], [0])
self.data_reader = DiskAudio([0], [0], [1, 60, 102])
# frames, labels = self.data_reader.extract_frames_from_wav(feature, 0)
frames, labels = self.data_reader.split_features_in_windows(feature, 0, )
frames = numpy.asarray(frames)
logger.debug(" .... And %d frames are extracted to pass into DNN model" % frames.shape[0])
frames = numpy.reshape(frames, (frames.shape[0], -1, 1))
if self.session is not None:
forward_output = self.session.run(self.dnn_model, feed_dict={self.data_placeholder: frames})
else:
raise ValueError("Tensorflow session was not initialized, so cannot project on DNN model!")
return forward_output
frames, labels = self.data_reader.split_features_in_windows(feature, 0, 60)
# frames = numpy.asarray(frames)
# logger.info(" .... And frames of shape {0} are extracted to pass into DNN model".format(frames.shape))
projections = numpy.zeros((len(frames), 2), dtype=numpy.float32)
for frame, i in zip(frames, range(len(frames))):
frame = numpy.reshape(frame, ([1] + list(frames[0].shape)))
# frames = numpy.reshape(frames, (frames.shape[0], -1, 1))
logger.info(" .... projecting frame of shape {0} onto DNN model".format(frame.shape))
if self.session is not None:
forward_output = self.session.run(self.dnn_model, feed_dict={self.data_placeholder: frame})
projections[i]=forward_output[0]
else:
raise ValueError("Tensorflow session was not initialized, so cannot project on DNN model!")
logger.info("Projected scores {0}".format(projections))
return numpy.asarray(projections, dtype=numpy.float32)
def project(self, feature):
"""project(feature) -> projected
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment