Skip to content
Snippets Groups Projects
Commit 202ac4ab authored by Pavel KORSHUNOV's avatar Pavel KORSHUNOV
Browse files

make sure can work with hdf5 models

parent 287b1b73
No related branches found
No related tags found
1 merge request!6Merge branch with audio-stuff into master
......@@ -120,11 +120,11 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
# Feeding the placeholder
if self.inference_placeholder is None:
self.compute_inference_placeholder(data.shape[1:])
self.compute_inference_placeholder([None] + list(data.shape[1:]))
feed_dict = {self.inference_placeholder: data}
if self.inference_graph is None:
self.compute_inference_graph(self.inference_placeholder, feature_layer)
self.compute_inference_graph(feature_layer)
embedding = session.run([self.inference_graph], feed_dict=feed_dict)[0]
......@@ -266,7 +266,8 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
hdf5.cd('/tensor_flow')
for k in self.sequence_net:
# TODO: IT IS NOT SMART TESTING ALONG THIS PAGE
if not isinstance(self.sequence_net[k], MaxPooling):
if not isinstance(self.sequence_net[k], MaxPooling) and \
not isinstance(self.sequence_net[k], LogSoftMax):
self.sequence_net[k].W.assign(hdf5.read(self.sequence_net[k].W.name)).eval(session=session)
session.run(self.sequence_net[k].W)
self.sequence_net[k].b.assign(hdf5.read(self.sequence_net[k].b.name)).eval(session=session)
......@@ -300,11 +301,11 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
# Loading architecture
self.sequence_net = pickle.loads(hdf5.read('architecture'))
self.deployment_shape = hdf5.read('deployment_shape')
self.turn_gpu_onoff(use_gpu)
if shape is None:
self.deployment_shape = hdf5.read('deployment_shape')
shape = self.deployment_shape
shape[0] = batch
......@@ -312,7 +313,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
place_holder = tf.placeholder(tf.float32, shape=shape, name="load")
self.compute_graph(place_holder)
tf.initialize_all_variables().run(session=session)
self.load_variables_only(hdf5, session)
self.load_variables_only(hdf5)
def save(self, saver, path):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment