Skip to content
Snippets Groups Projects
Commit ee40233a authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Merge branch 'fix-style-transfer' into 'master'

[style] Fixed issues with style-transfer

Closes #82

See merge request !80
parents 7e18ff07 c38c3ef6
No related branches found
No related tags found
1 merge request!80[style] Fixed issues with style-transfer
Pipeline #32531 failed
......@@ -5,6 +5,7 @@
import logging
import tensorflow as tf
logger = logging.getLogger("bob.learn.tensorflow")
import functools
def content_loss(noises, content_features):
......@@ -33,7 +34,7 @@ def content_loss(noises, content_features):
content_losses = []
for n,c in zip(noises, content_features):
content_losses.append((2 * tf.nn.l2_loss(n - c) / c.size))
return reduce(tf.add, content_losses)
return functools.reduce(tf.add, content_losses)
def linear_gram_style_loss(noises, gram_style_features):
......@@ -63,7 +64,7 @@ def linear_gram_style_loss(noises, gram_style_features):
for n,s in zip(noises, gram_style_features):
style_losses.append((2 * tf.nn.l2_loss(n - s)) / s.size)
return reduce(tf.add, style_losses)
return functools.reduce(tf.add, style_losses)
......@@ -82,7 +83,7 @@ def denoising_loss(noise):
"""
def _tensor_size(tensor):
from operator import mul
return reduce(mul, (d.value for d in tensor.get_shape()), 1)
return functools.reduce(mul, (d.value for d in tensor.get_shape()), 1)
shape = noise.get_shape().as_list()
......
......@@ -49,7 +49,7 @@ logger = logging.getLogger(__name__)
help='Number of iterations to generate the image',
default=1000)
@click.option('--learning_rate',
'-i',
'-r',
type=click.types.FLOAT,
help='Learning rate.',
default=1.)
......
......@@ -163,6 +163,7 @@ def do_style_transfer(content_image, style_images,
# Reshaping to NxWxHxC and converting to the tensorflow format
# content
original_image = content_image
content_image = bob.io.image.to_matplotlib(content_image).astype("float32")
content_image = numpy.reshape(content_image, wise_shape(content_image.shape))
......@@ -253,11 +254,12 @@ def do_style_transfer(content_image, style_images,
if normalized_style_image.shape[0] == 1:
normalized_style_image_yuv = bob.ip.color.rgb_to_yuv(bob.ip.color.gray_to_rgb(normalized_style_image[0,:,:]))
# Loading the content image and clipping from 0-255 in case is in another scale
scaled_content_image = normalize4save(bob.io.base.load(content_image_path).astype("float32")).astype("float64")
#scaled_content_image = normalize4save(bob.io.base.load(content_image_path).astype("float32")).astype("float64")
scaled_content_image = original_image.astype("float64")
content_image_yuv = bob.ip.color.rgb_to_yuv(bob.ip.color.gray_to_rgb(scaled_content_image))
else:
normalized_style_image_yuv = bob.ip.color.rgb_to_yuv(bob.ip.color.gray_to_rgb(bob.ip.color.rgb_to_gray(normalized_style_image)))
content_image_yuv = bob.ip.color.rgb_to_yuv(bob.io.base.load(content_image_path))
content_image_yuv = bob.ip.color.rgb_to_yuv(original_image)
output_image = numpy.zeros(shape=content_image_yuv.shape, dtype="uint8")
output_image[0,:,:] = normalized_style_image_yuv[0,:,:]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment