Commit e88c94f7 authored by Pavel KORSHUNOV

merged

parents 97356a49 5b428b7c
Pipeline #14720 passed with stages
in 6 minutes and 18 seconds
......@@ -14,12 +14,13 @@ import logging
logger = logging.getLogger("bob.pad.voice")
class LSTMEval(Algorithm):
class TensorflowEval(Algorithm):
"""This class is for evaluating data stored in tensorflow tfrecord format using a pre-trained LSTM model."""
def __init__(self,
architecture_name="mlp",
input_shape=[200, 81], # [temporal_length, feature_size]
lstm_network_size=60, # the output size of LSTM cell
network_size=60, # the output size of LSTM cell
normalization_file=None, # file with normalization parameters from train set
**kwargs):
"""Generates a test value that is read and written"""
......@@ -32,17 +33,19 @@ class LSTMEval(Algorithm):
**kwargs
)
self.architecture_name = architecture_name
self.input_shape = input_shape
self.num_time_steps = input_shape[0]
self.lstm_network_size = lstm_network_size
self.network_size = network_size
self.data_std = None
# import ipdb
# ipdb.set_trace()
features_length = input_shape[1]
if normalization_file and os.path.exists(normalization_file):
logger.info("Loading normalization file '%s' " % normalization_file)
npzfile = numpy.load(normalization_file)
self.data_mean = npzfile['data_mean']
self.data_std = npzfile['data_std']
self.data_std = numpy.array(npzfile['data_std'])
if not self.data_std.shape: # if std was saved as scalar
self.data_std = numpy.ones(features_length)
# if self.data_mean.shape[0] > input_shape[0]:
......@@ -52,6 +55,7 @@ class LSTMEval(Algorithm):
# self.data_std = self.data_std[:input_shape[0]]
# self.data_std = numpy.reshape(self.data_std, input_shape)
else:
logger.warn("Normalization file '%s' does not exist!" % normalization_file)
self.data_mean = 0
self.data_std = 1
......@@ -60,29 +64,29 @@ class LSTMEval(Algorithm):
self.dnn_model = None
self.data_placeholder = None
def simple_lstm_network(self, train_data_shuffler, batch_size=10, lstm_cell_size=64,
                        num_time_steps=28, num_classes=10, seed=10, reuse=False):
    """Assemble an ``lstm`` layer followed by a linear fully connected layer.

    ``train_data_shuffler`` is used directly as the network input when it is a
    ``tf.Tensor``; otherwise it is called as
    ``train_data_shuffler("data", from_queue=False)`` to obtain the input.

    Returns the output of the final ``fc1`` fully connected layer, which has
    ``num_classes`` units and no activation function.
    """
    import tensorflow as tf
    from bob.learn.tensorflow.layers import lstm

    # Accept either a ready tensor or an object that produces the input on call.
    if not isinstance(train_data_shuffler, tf.Tensor):
        net_input = train_data_shuffler("data", from_queue=False)
    else:
        net_input = train_data_shuffler

    # One seeded Xavier initializer is shared by both layers.
    weights_init = tf.contrib.layers.xavier_initializer(seed=seed)

    # LSTM layer from bob.learn.tensorflow.
    lstm_output = lstm(
        net_input,
        lstm_cell_size,
        num_time_steps=num_time_steps,
        batch_size=batch_size,
        output_activation_size=num_classes,
        scope='lstm',
        name='sync_cell',
        weights_initializer=weights_init,
        activation=tf.nn.sigmoid,
        reuse=reuse,
    )

    # Linear projection of the LSTM output onto the classes.
    return tf.contrib.slim.fully_connected(
        lstm_output,
        num_classes,
        activation_fn=None,
        scope='fc1',
        weights_initializer=weights_init,
        reuse=reuse,
    )
# def simple_lstm_network(self, train_data_shuffler, batch_size=10, lstm_cell_size=64,
# num_time_steps=28, num_classes=10, seed=10, reuse=False):
# import tensorflow as tf
# from bob.learn.tensorflow.layers import lstm
# slim = tf.contrib.slim
#
# if isinstance(train_data_shuffler, tf.Tensor):
# inputs = train_data_shuffler
# else:
# inputs = train_data_shuffler("data", from_queue=False)
#
# initializer = tf.contrib.layers.xavier_initializer(seed=seed)
#
# # Creating an LSTM network
# graph = lstm(inputs, lstm_cell_size, num_time_steps=num_time_steps, batch_size=batch_size,
# output_activation_size=num_classes, scope='lstm', name='sync_cell',
# weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse)
#
# # fully connect the LSTM output to the classes
# graph = slim.fully_connected(graph, num_classes, activation_fn=None, scope='fc1',
# weights_initializer=initializer, reuse=reuse)
#
# return graph
def normalize_data(self, features):
mean = numpy.mean(features, axis=0)
......@@ -97,16 +101,34 @@ class LSTMEval(Algorithm):
def restore_trained_model(self, projector_file):
import tensorflow as tf
from bob.learn.tensorflow.network import LightCNN9
if self.session is None:
self.session = tf.Session()
data_pl = tf.placeholder(tf.float32, shape=(None,) + tuple(self.input_shape), name="data")
# network = LightCNN9(n_classes=2, device="/cpu:0")
# graph = network(data_pl)
graph = self.simple_lstm_network(data_pl, batch_size=1,
lstm_cell_size=self.lstm_network_size, num_time_steps=self.num_time_steps,
num_classes=2, reuse=False)
# create an empty graph of the correct architecture but with needed batch_size==1
if self.architecture_name == 'lstm':
from bob.learn.tensorflow.network import simple_lstm_network
graph = simple_lstm_network(data_pl, batch_size=1,
lstm_cell_size=self.network_size, num_time_steps=self.num_time_steps,
num_classes=2, reuse=False)
elif self.architecture_name == 'mlp':
from bob.learn.tensorflow.network import mlp_network
graph = mlp_network(data_pl,
hidden_layer_size=self.network_size,
num_time_steps=self.num_time_steps,
num_classes=2, reuse=False)
elif self.architecture_name == 'simplecnn':
from bob.learn.tensorflow.network import simple2Dcnn_network
graph = simple2Dcnn_network(data_pl,
num_classes=2, reuse=False)
elif self.architecture_name == 'lightcnn':
from bob.learn.tensorflow.network import LightCNN9
network = LightCNN9(n_classes=2, device="/cpu:0")
graph = network(data_pl, reuse=False)
else:
return None
self.session.run(tf.global_variables_initializer())
saver = tf.train.Saver()
......@@ -201,4 +223,4 @@ class LSTMEval(Algorithm):
return [toscore[0]]
algorithm = LSTMEval()
algorithm = TensorflowEval()
from .GMM import GMM
from .LogRegr import LogRegr
from .LSTMEval import LSTMEval
from .TensorflowEval import TensorflowEval
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
......@@ -19,6 +19,6 @@ def __appropriate__(*args):
__appropriate__(
GMM,
LogRegr,
LSTMEval,
TensorflowEval,
)
__all__ = [_ for _ in dir() if not _.startswith('_')]
......@@ -116,7 +116,7 @@ setup(
],
'bob.pad.algorithm': [
'tensorflow = bob.pad.voice.algorithm.LSTMEval:algorithm',
'tensorflow = bob.pad.voice.algorithm.TensorflowEval:algorithm',
'dummy-algo = bob.pad.voice.algorithm.dummy:algorithm',
# compute scores based on different energy bands
'logregr = bob.pad.voice.algorithm.LogRegr:algorithm',
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment