Commit b0e3aa49 authored by Pavel KORSHUNOV's avatar Pavel KORSHUNOV

minor fix of tf evaluation

parent e88c94f7
...@@ -104,7 +104,8 @@ class TensorflowEval(Algorithm): ...@@ -104,7 +104,8 @@ class TensorflowEval(Algorithm):
if self.session is None: if self.session is None:
self.session = tf.Session() self.session = tf.Session()
data_pl = tf.placeholder(tf.float32, shape=(None,) + tuple(self.input_shape), name="data") # add extra dimension to the input, so that 2D convolution would work
data_pl = tf.placeholder(tf.float32, shape=(None,) + tuple(self.input_shape) + (1,), name="data")
# create an empty graph of the correct architecture but with needed batch_size==1 # create an empty graph of the correct architecture but with needed batch_size==1
if self.architecture_name == 'lstm': if self.architecture_name == 'lstm':
...@@ -166,7 +167,9 @@ class TensorflowEval(Algorithm): ...@@ -166,7 +167,9 @@ class TensorflowEval(Algorithm):
projections = numpy.zeros((len(frames), 2), dtype=numpy.float32) projections = numpy.zeros((len(frames), 2), dtype=numpy.float32)
for i in range(frames.shape[0]): for i in range(frames.shape[0]):
frame = frames[i] frame = frames[i]
frame = numpy.reshape(frame, [1] + self.input_shape) # reshape to 4D shape, so that all networks, including CNN-based
# would work propery
frame = numpy.reshape(frame, [1] + self.input_shape + [1])
#logger.info(" .... projecting frame of shape {0} onto DNN model".format(frame.shape)) #logger.info(" .... projecting frame of shape {0} onto DNN model".format(frame.shape))
if self.session is not None: if self.session is not None:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment