Commit e5a29aa5 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Improve patch extractor

parent 30dd85ea
......@@ -395,7 +395,18 @@ def all_patches(image, label, key, size):
blocks, n_blocks = blocks_tensorflow(image, size)
# duplicate label and key as n_blocks
label = tf_repeat(label, [n_blocks])
key = tf_repeat(key, [n_blocks])
def repeats(shape):
r = shape.as_list()
for i in range(len(r)):
if i == 0:
r[i] = n_blocks
r[i] = 1
return r
label = tf_repeat(label, repeats(label.shape))
key = tf_repeat(key, repeats(key.shape))
return blocks, label, key
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment