diff --git a/bob/learn/tensorflow/network/SimpleCNN.py b/bob/learn/tensorflow/network/SimpleCNN.py index bb4626bd610cfb16c07c23513340a823db2d87ff..eb7a98a8c61b89f8cd745dfc5e07c2ec1583a7fd 100644 --- a/bob/learn/tensorflow/network/SimpleCNN.py +++ b/bob/learn/tensorflow/network/SimpleCNN.py @@ -47,8 +47,8 @@ def architecture(input_layer, mode=tf.estimator.ModeKeys.TRAIN, data_format=data_format) # Flatten tensor into a batch of vectors - dim = tf.reduce_prod(tf.shape(pool2)[1:]) - pool2_flat = tf.reshape(pool2, [-1, dim]) + # TODO: use tf.layers.flatten in tensorflow 1.4 above + pool2_flat = tf.contrib.layers.flatten(pool2) # Dense Layer # Densely connected layer with 1024 neurons