diff --git a/bob/learn/tensorflow/dataset/__init__.py b/bob/learn/tensorflow/dataset/__init__.py
index 51f99670608a4369779e01f2eea73124ff7593f0..efa367f135c2073db46742c750be9a459b27fe01 100644
--- a/bob/learn/tensorflow/dataset/__init__.py
+++ b/bob/learn/tensorflow/dataset/__init__.py
@@ -329,18 +329,12 @@ def blocks_tensorflow(images, block_size):
block_size = [1] + list(block_size) + [1]
output_size = list(block_size)
output_size[0] = -1
- # extract image patches for each color space:
- output = []
- for i in range(3):
- blocks = tf.extract_image_patches(
- images[:, :, :, i : i + 1], block_size, block_size, [1, 1, 1, 1], "VALID"
- )
- if i == 0:
- n_blocks = int(numpy.prod(blocks.shape[1:3]))
- blocks = tf.reshape(blocks, output_size)
- output.append(blocks)
- # concatenate the colors back
- output = tf.concat(output, axis=3)
+ output_size[-1] = images.shape[-1]
+ blocks = tf.extract_image_patches(
+ images, block_size, block_size, [1, 1, 1, 1], "VALID"
+ )
+ n_blocks = int(numpy.prod(blocks.shape[1:3]))
+ output = tf.reshape(blocks, output_size)
return output, n_blocks