Commit cea41d17 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed small bug with the shape

parent 5d085e94
Pipeline #21864 passed with stage
in 21 minutes and 22 seconds
...@@ -168,10 +168,10 @@ def style_transfer(content_image_path, output_path, style_image_paths, ...@@ -168,10 +168,10 @@ def style_transfer(content_image_path, output_path, style_image_paths,
""" """
# Reading and converting to the tensorflow format # Reading and converting to the tensorflow format
content_image = bob.io.image.to_matplotlib(bob.io.base.load(content_image_path)) content_image = bob.io.image.to_matplotlib(bob.io.base.load(content_image_path)).astype("float32")
style_images = [] style_images = []
for path in style_image_paths: for path in style_image_paths:
image = bob.io.image.to_matplotlib(bob.io.base.load(path)) image = bob.io.image.to_matplotlib(bob.io.base.load(path)).astype("float32")
style_images.append(numpy.reshape(image, wise_shape(image.shape))) style_images.append(numpy.reshape(image, wise_shape(image.shape)))
# Reshaping to NxWxHxC # Reshaping to NxWxHxC
......
...@@ -26,12 +26,11 @@ def compute_features(input_image, architecture, checkpoint_dir, target_end_point ...@@ -26,12 +26,11 @@ def compute_features(input_image, architecture, checkpoint_dir, target_end_point
Dictionary containing the end point tensors Dictionary containing the end point tensors
""" """
input_pl = tf.placeholder('float32', shape=(1, input_image.shape[1], input_pl = tf.placeholder('float32', shape=(1, input_image.shape[1],
input_image.shape[2], input_image.shape[2],
input_image.shape[3])) input_image.shape[3]))
# TODO: Think on how abstract this normalization operation # TODO: Think on how abstract this normalization operation
_, end_points = architecture(tf.stack([tf.image.per_image_standardization(i) for i in tf.unstack(input_image)]), mode=tf.estimator.ModeKeys.PREDICT, trainable_variables=None) _, end_points = architecture(tf.stack([tf.image.per_image_standardization(i) for i in tf.unstack(input_pl)]), mode=tf.estimator.ModeKeys.PREDICT, trainable_variables=None)
with tf.Session() as sess: with tf.Session() as sess:
# Restoring the checkpoint for the given architecture # Restoring the checkpoint for the given architecture
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment