Skip to content
Snippets Groups Projects
Commit f6d1eba1 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[layers] fixed the concatenation of the conditional variable

parent 735979b1
No related branches found
No related tags found
1 merge request!8Gan
......@@ -46,7 +46,5 @@ class ConditionConcat(Layer):
def get_graph(self, y):
batch_size = input_layer.get_shape()[0]
yb = tf.reshape(y, [batch_size, 1, 1, self.conditional_dim])
return tf.concat([input_layer, y], 1)
return tf.concat([self.input_layer, y], 1)
......@@ -46,6 +46,8 @@ class ImToCondFeatureMap(Layer):
use_gpu=False
)
self.conditional_dim = conditional_dim
logger.info("+ adding a Concatenation layer ({0}) +".format(name))
logger.info("\t conditional dimension = {0}".format(conditional_dim))
......@@ -61,9 +63,11 @@ class ImToCondFeatureMap(Layer):
name: y
The conditioning vector (one-hot encoded)
"""
# concatenate them to the image
x_shapes = input_layer.get_shape()
y_shapes = y.get_shape()
return tf.concat([input_layer, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3)
"""
# concatenate to the image
batch_size = tf.shape(self.input_layer)[0]
yb = tf.reshape(y, [batch_size, 1, 1, self.conditional_dim])
x_shapes = self.input_layer.get_shape()
y_shapes = yb.get_shape()
return tf.concat([self.input_layer, yb*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment