diff --git a/bob/learn/tensorflow/dataset/__init__.py b/bob/learn/tensorflow/dataset/__init__.py index 51f99670608a4369779e01f2eea73124ff7593f0..efa367f135c2073db46742c750be9a459b27fe01 100644 --- a/bob/learn/tensorflow/dataset/__init__.py +++ b/bob/learn/tensorflow/dataset/__init__.py @@ -329,18 +329,12 @@ def blocks_tensorflow(images, block_size): block_size = [1] + list(block_size) + [1] output_size = list(block_size) output_size[0] = -1 - # extract image patches for each color space: - output = [] - for i in range(3): - blocks = tf.extract_image_patches( - images[:, :, :, i : i + 1], block_size, block_size, [1, 1, 1, 1], "VALID" - ) - if i == 0: - n_blocks = int(numpy.prod(blocks.shape[1:3])) - blocks = tf.reshape(blocks, output_size) - output.append(blocks) - # concatenate the colors back - output = tf.concat(output, axis=3) + output_size[-1] = images.shape[-1] + blocks = tf.extract_image_patches( + images, block_size, block_size, [1, 1, 1, 1], "VALID" + ) + n_blocks = int(numpy.prod(blocks.shape[1:3])) + output = tf.reshape(blocks, output_size) return output, n_blocks