diff --git a/bob/learn/tensorflow/network/SequenceNetwork.py b/bob/learn/tensorflow/network/SequenceNetwork.py index 8a494af9623b246ae434e6d6082d40d025f105d0..d755297174020fe21c27508e3267a8b14133c4c6 100644 --- a/bob/learn/tensorflow/network/SequenceNetwork.py +++ b/bob/learn/tensorflow/network/SequenceNetwork.py @@ -220,7 +220,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)): hdf5.set('input_divide', self.input_divide) hdf5.set('input_subtract', self.input_subtract) - def load(self, hdf5, shape, session=None): + def load(self, hdf5, shape=None, session=None): """ Load the network @@ -244,6 +244,9 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)): self.sequence_net = pickle.loads(hdf5.read('architecture')) self.deployment_shape = hdf5.read('deployment_shape') + if shape is None: + shape = self.deployment_shape + # Loading variables place_holder = tf.placeholder(tf.float32, shape=shape, name="load") self.compute_graph(place_holder)