From e34b58c07e76efe64ebf81732f02740087c49d32 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Wed, 26 Oct 2016 21:55:29 +0200 Subject: [PATCH] Shape --- bob/learn/tensorflow/network/SequenceNetwork.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bob/learn/tensorflow/network/SequenceNetwork.py b/bob/learn/tensorflow/network/SequenceNetwork.py index 8a494af9..d7552971 100644 --- a/bob/learn/tensorflow/network/SequenceNetwork.py +++ b/bob/learn/tensorflow/network/SequenceNetwork.py @@ -220,7 +220,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)): hdf5.set('input_divide', self.input_divide) hdf5.set('input_subtract', self.input_subtract) - def load(self, hdf5, shape, session=None): + def load(self, hdf5, shape=None, session=None): """ Load the network @@ -244,6 +244,9 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)): self.sequence_net = pickle.loads(hdf5.read('architecture')) self.deployment_shape = hdf5.read('deployment_shape') + if shape is None: + shape = self.deployment_shape + # Loading variables place_holder = tf.placeholder(tf.float32, shape=shape, name="load") self.compute_graph(place_holder) -- GitLab