diff --git a/bob/learn/tensorflow/network/SimpleCNN.py b/bob/learn/tensorflow/network/SimpleCNN.py
index bb4626bd610cfb16c07c23513340a823db2d87ff..eb7a98a8c61b89f8cd745dfc5e07c2ec1583a7fd 100644
--- a/bob/learn/tensorflow/network/SimpleCNN.py
+++ b/bob/learn/tensorflow/network/SimpleCNN.py
@@ -47,8 +47,8 @@ def architecture(input_layer, mode=tf.estimator.ModeKeys.TRAIN,
                                     data_format=data_format)
 
     # Flatten tensor into a batch of vectors
-    dim = tf.reduce_prod(tf.shape(pool2)[1:])
-    pool2_flat = tf.reshape(pool2, [-1, dim])
+    # TODO: use tf.layers.flatten in tensorflow 1.4 above
+    pool2_flat = tf.contrib.layers.flatten(pool2)
 
     # Dense Layer
     # Densely connected layer with 1024 neurons