Commit e6799f2c authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Enabled multiple style images

parent fffc4590
...@@ -25,11 +25,19 @@ logger = logging.getLogger(__name__) ...@@ -25,11 +25,19 @@ logger = logging.getLogger(__name__)
def normalize4save(img):
    """Scale *img* to the full [0, 255] range and cast to uint8 for saving.

    Parameters
    ----------
    img : numpy.ndarray
        Array with arbitrary numeric range.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as *img* with dtype ``uint8``: the minimum
        value maps to 0 and the maximum to 255.  A constant image
        (``max == min``) maps to all zeros — the original expression would
        divide by zero in that case.
    """
    lo = numpy.min(img)
    span = numpy.max(img) - lo
    if span == 0:
        # Guard against a flat (constant) image: avoid division by zero.
        return numpy.zeros_like(img, dtype="uint8")
    return (255 * ((img - lo) / span)).astype("uint8")
#@click.argument('style_image_path', required=True)
@click.command( @click.command(
entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand) entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
@click.argument('content_image', required=True) @click.argument('content_image_path', required=True)
@click.argument('style_image', required=True) @click.argument('output_path', required=True)
@click.argument('output', required=True) @click.option('--style-image-paths',
cls=ResourceOption,
required=True,
multiple=True,
entry_point_group='bob.learn.tensorflow.style_images',
help='List of images that encode the style.')
@click.option('--architecture', @click.option('--architecture',
'-a', '-a',
required=True, required=True,
...@@ -78,42 +86,52 @@ def normalize4save(img): ...@@ -78,42 +86,52 @@ def normalize4save(img):
entry_point_group='bob.learn.tensorflow.scopes', entry_point_group='bob.learn.tensorflow.scopes',
help='Dictionary containing the mapping scores', help='Dictionary containing the mapping scores',
required=True) required=True)
@click.option('--pure-noise',
is_flag=True,
help="If set will save the raw noisy generated image."
" If not set, the output will be RGB = stylizedYUV.Y, originalYUV.U, originalYUV.V"
)
@verbosity_option(cls=ResourceOption) @verbosity_option(cls=ResourceOption)
def style_transfer(content_image, style_image, output, def style_transfer(content_image_path, output_path, style_image_paths,
architecture, checkpoint_dir, architecture, checkpoint_dir,
iterations, learning_rate, iterations, learning_rate,
content_weight, style_weight, denoise_weight, content_end_points, content_weight, style_weight, denoise_weight, content_end_points,
style_end_points, scopes, **kwargs): style_end_points, scopes, pure_noise, **kwargs):
""" """
Trains neural style transfer from Trains neural style transfer from
Gatys, Leon A., Alexander S. Ecker, and Matthias Bethge. "A neural algorithm of artistic style." arXiv preprint arXiv:1508.06576 (2015). Gatys, Leon A., Alexander S. Ecker, and Matthias Bethge. "A neural algorithm of artistic style." arXiv preprint arXiv:1508.06576 (2015).
""" """
# Reading and converting to the tensorflow format # Reading and converting to the tensorflow format
content_image = bob.io.image.to_matplotlib(bob.io.base.load(content_image)) content_image = bob.io.image.to_matplotlib(bob.io.base.load(content_image_path))
style_image = bob.io.image.to_matplotlib(bob.io.base.load(style_image)) style_images = []
for path in style_image_paths:
image = bob.io.image.to_matplotlib(bob.io.base.load(path))
style_images.append(numpy.reshape(image, (1, image.shape[0],
image.shape[1],
image.shape[2])))
# Reshaping to NxWxHxC # Reshaping to NxWxHxC
content_image = numpy.reshape(content_image, (1, content_image.shape[0], content_image = numpy.reshape(content_image, (1, content_image.shape[0],
content_image.shape[1], content_image.shape[1],
content_image.shape[2])) content_image.shape[2]))
style_image = numpy.reshape(style_image, (1, style_image.shape[0],
style_image.shape[1],
style_image.shape[2]))
# Base content features # Base content features
content_features = compute_features(content_image, architecture, checkpoint_dir, content_end_points) content_features = compute_features(content_image, architecture, checkpoint_dir, content_end_points)
# Base style features # Base style features
# TODO: Enable a set of style images # TODO: Enable a set of style images
style_features = compute_features(style_image, architecture, checkpoint_dir, style_end_points) style_grams = []
style_grams = compute_gram(style_features) for image in style_images:
style_features = compute_features(image, architecture, checkpoint_dir, style_end_points)
style_grams.append(compute_gram(style_features))
# Organizing the trainer # Organizing the trainer
with tf.Graph().as_default(): with tf.Graph().as_default():
tf.set_random_seed(0)
# Random noise # Random noise
noise = tf.Variable(tf.random_normal(shape=content_image.shape), noise = tf.Variable(tf.random_normal(shape=content_image.shape),
trainable=True) trainable=True)
...@@ -130,18 +148,20 @@ def style_transfer(content_image, style_image, output, ...@@ -130,18 +148,20 @@ def style_transfer(content_image, style_image, output,
# Computing style_loss # Computing style_loss
style_gram_noises = [] style_gram_noises = []
for c in style_end_points: s_loss = 0
layer = end_points[c] for grams_per_image in style_grams:
_, height, width, number = map(lambda i: i.value, layer.get_shape())
size = height * width * number for c in style_end_points:
features = tf.reshape(layer, (-1, number)) layer = end_points[c]
style_gram_noises.append(tf.matmul(tf.transpose(features), features) / size) _, height, width, number = map(lambda i: i.value, layer.get_shape())
s_loss = linear_gram_style_loss(style_gram_noises, style_grams) size = height * width * number
features = tf.reshape(layer, (-1, number))
style_gram_noises.append(tf.matmul(tf.transpose(features), features) / size)
s_loss += linear_gram_style_loss(style_gram_noises, grams_per_image)
# Variation denoise # Variation denoise
d_loss = denoising_loss(noise) d_loss = denoising_loss(noise)
#Total loss #Total loss
total_loss = content_weight*c_loss + style_weight*s_loss + denoise_weight*d_loss total_loss = content_weight*c_loss + style_weight*s_loss + denoise_weight*d_loss
...@@ -149,7 +169,7 @@ def style_transfer(content_image, style_image, output, ...@@ -149,7 +169,7 @@ def style_transfer(content_image, style_image, output,
tf.contrib.framework.init_from_checkpoint(tf.train.latest_checkpoint(checkpoint_dir), tf.contrib.framework.init_from_checkpoint(tf.train.latest_checkpoint(checkpoint_dir),
scopes) scopes)
# Training
with tf.Session() as sess: with tf.Session() as sess:
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
...@@ -158,11 +178,24 @@ def style_transfer(content_image, style_image, output, ...@@ -158,11 +178,24 @@ def style_transfer(content_image, style_image, output,
print("Iteration {0}, loss {1}".format(i, loss)) print("Iteration {0}, loss {1}".format(i, loss))
sys.stdout.flush() sys.stdout.flush()
style_image = sess.run(noise)[0, :, :,:] # Saving generated image
style_image = bob.io.image.to_bob(style_image) raw_style_image = sess.run(noise)[0, :, :,:]
bob.io.base.save(normalize4save(style_image), output) raw_style_image = bob.io.image.to_bob(raw_style_image)
normalized_style_image = normalize4save(raw_style_image)
if pure_noise:
bob.io.base.save(normalized_style_image, output_path)
else:
# Original output
normalized_style_image_yuv = bob.ip.color.rgb_to_yuv(bob.ip.color.gray_to_rgb(bob.ip.color.rgb_to_gray(normalized_style_image)))
content_image_yuv = bob.ip.color.rgb_to_yuv(bob.io.base.load(content_image_path))
output_image = numpy.zeros(shape=content_image_yuv.shape, dtype="uint8")
output_image[0,:,:] = normalized_style_image_yuv[0,:,:]
output_image[1,:,:] = content_image_yuv[1,:,:]
output_image[2,:,:] = content_image_yuv[2,:,:]
output_image = bob.ip.color.yuv_to_rgb(output_image)
bob.io.base.save(output_image, output_path)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment