From ccd8bca8fa66357b0931a0c2ac472c25edcba279 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Mon, 27 May 2019 10:24:32 +0200
Subject: [PATCH] improve block extraction

---
 bob/learn/tensorflow/dataset/__init__.py | 18 ++++++------------
 1 file changed, 6 insertions(+), 12 deletions(-)

diff --git a/bob/learn/tensorflow/dataset/__init__.py b/bob/learn/tensorflow/dataset/__init__.py
index 51f99670..efa367f1 100644
--- a/bob/learn/tensorflow/dataset/__init__.py
+++ b/bob/learn/tensorflow/dataset/__init__.py
@@ -329,18 +329,12 @@ def blocks_tensorflow(images, block_size):
     block_size = [1] + list(block_size) + [1]
     output_size = list(block_size)
     output_size[0] = -1
-    # extract image patches for each color space:
-    output = []
-    for i in range(3):
-        blocks = tf.extract_image_patches(
-            images[:, :, :, i : i + 1], block_size, block_size, [1, 1, 1, 1], "VALID"
-        )
-        if i == 0:
-            n_blocks = int(numpy.prod(blocks.shape[1:3]))
-        blocks = tf.reshape(blocks, output_size)
-        output.append(blocks)
-    # concatenate the colors back
-    output = tf.concat(output, axis=3)
+    output_size[-1] = images.shape[-1]
+    blocks = tf.extract_image_patches(
+        images, block_size, block_size, [1, 1, 1, 1], "VALID"
+    )
+    n_blocks = int(numpy.prod(blocks.shape[1:3]))
+    output = tf.reshape(blocks, output_size)
     return output, n_blocks
 
 
-- 
GitLab