Commit 89b2565b authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Implemented the preprocess_fn and un_preprocess_fn

parent 60cf788d
"""
Example using inception resnet v1
"""
import tensorflow as tf
# -- architecture
from bob.learn.tensorflow.network import inception_resnet_v1_batch_norm
architecture = inception_resnet_v1_batch_norm
# --checkpoint-dir
from bob.extension import rc
checkpoint_dir = rc['bob.bio.face_ongoing.inception-v1_batchnorm_rgb']
# --style-end-points and -- content-end-points
style_end_points = ["Conv2d_1a_3x3", "Conv2d_2b_3x3"]
content_end_points = ["Block8"]
scopes = {"InceptionResnetV1/":"InceptionResnetV1/"}
# --style-image-paths
style_image_paths = ["vincent_van_gogh.jpg",
"vincent_van_gogh2.jpg"]
# --preprocess-fn
preprocess_fn = tf.image.per_image_standardization
"""
Example using inception resnet v2
"""
import tensorflow as tf
# -- architecture
from bob.learn.tensorflow.network import inception_resnet_v2_batch_norm
architecture = inception_resnet_v2_batch_norm
# --checkpoint-dir
from bob.extension import rc
checkpoint_dir = rc['bob.bio.face_ongoing.inception-v2_batchnorm_rgb']
# --style-end-points and -- content-end-points
style_end_points = ["Conv2d_1a_3x3", "Conv2d_2b_3x3"]
content_end_points = ["Block8"]
scopes = {"InceptionResnetV2/":"InceptionResnetV2/"}
# --style-image-paths
style_image_paths = ["vincent_van_gogh.jpg",
"vincent_van_gogh2.jpg"]
# --preprocess-fn
preprocess_fn = tf.image.per_image_standardization
"""
Example using VGG19
"""
from bob.learn.tensorflow.network import vgg_19
# --architecture
architecture = vgg_19
import numpy
# -- checkpoint-dir
# YOU CAN DOWNLOAD THE CHECKPOINTS FROM HERE
# https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models
checkpoint_dir = "[VGG-19-CHECKPOINT]"
# --style-end-points and -- content-end-points
content_end_points = ['vgg_19/conv4/conv4_2', 'vgg_19/conv5/conv5_2']
style_end_points = ['vgg_19/conv1/conv1_2',
'vgg_19/conv2/conv2_1',
'vgg_19/conv3/conv3_1',
'vgg_19/conv4/conv4_1',
'vgg_19/conv5/conv5_1'
]
scopes = {"vgg_19/":"vgg_19/"}
style_image_paths = ["vincent_van_gogh.jpg",
"vincent_van_gogh2.jpg"]
# --preprocess-fn and --un-preprocess-fn
# Taken from VGG19: per-channel RGB means of the original VGG training data.
# Hoisted to a module-level constant so the array is not rebuilt on every call.
_VGG_MEAN = numpy.array([123.68, 116.779, 103.939])


def mean_norm(tensor):
    """Subtract the VGG per-channel mean from ``tensor`` (input preprocessing).

    Broadcasts over the last axis, so any (..., 3) array works.
    """
    return tensor - _VGG_MEAN


def un_mean_norm(tensor):
    """Add the VGG per-channel mean back to ``tensor`` (output un-preprocessing).

    Exact inverse of :func:`mean_norm`.
    """
    return tensor + _VGG_MEAN


preprocess_fn = mean_norm
un_preprocess_fn = un_mean_norm
......@@ -66,11 +66,11 @@ def normalize4save(img):
@click.option('--content-weight',
type=click.types.FLOAT,
help='Weight of the content loss.',
default=1.)
default=5.)
@click.option('--style-weight',
type=click.types.FLOAT,
help='Weight of the style loss.',
default=1000.)
default=100.)
@click.option('--denoise-weight',
type=click.types.FLOAT,
help='Weight denoising loss.',
......@@ -95,12 +95,23 @@ def normalize4save(img):
help="If set will save the raw noisy generated image."
"If not set, the output will be RGB = stylizedYUV.Y, originalYUV.U, originalYUV.V"
)
@click.option('--preprocess-fn',
'-pr',
cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.preprocess_fn',
help='Preprocess function. Pointer to a function that preprocess the INPUT signal')
@click.option('--un-preprocess-fn',
'-un',
cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.preprocess_fn',
help='Un preprocess function. Pointer to a function that preprocess the OUTPUT signal')
@verbosity_option(cls=ResourceOption)
def style_transfer(content_image_path, output_path, style_image_paths,
architecture, checkpoint_dir,
iterations, learning_rate,
content_weight, style_weight, denoise_weight, content_end_points,
style_end_points, scopes, pure_noise, **kwargs):
style_end_points, scopes, pure_noise, preprocess_fn,
un_preprocess_fn, **kwargs):
"""
Trains neural style transfer using the approach presented in:
......@@ -179,13 +190,15 @@ def style_transfer(content_image_path, output_path, style_image_paths,
# Base content features
logger.info("Computing content features")
content_features = compute_features(content_image, architecture, checkpoint_dir, content_end_points)
content_features = compute_features(content_image, architecture, checkpoint_dir,
content_end_points, preprocess_fn)
# Base style features
logger.info("Computing style features")
style_grams = []
for image in style_images:
style_features = compute_features(image, architecture, checkpoint_dir, style_end_points)
style_features = compute_features(image, architecture, checkpoint_dir,
style_end_points, preprocess_fn)
style_grams.append(compute_gram(style_features))
# Organizing the trainer
......@@ -195,8 +208,7 @@ def style_transfer(content_image_path, output_path, style_image_paths,
# Random noise
noise = tf.Variable(tf.random_normal(shape=content_image.shape),
trainable=True)
trainable=True) * 0.256
_, end_points = architecture(noise,
mode=tf.estimator.ModeKeys.PREDICT,
trainable_variables=[])
......@@ -240,6 +252,10 @@ def style_transfer(content_image_path, output_path, style_image_paths,
# Saving generated image
raw_style_image = sess.run(noise)[0, :, :,:]
# Unpreprocessing the signal
if un_preprocess_fn is not None:
raw_style_image = un_preprocess_fn(raw_style_image)
raw_style_image = bob.io.image.to_bob(raw_style_image)
normalized_style_image = normalize4save(raw_style_image)
......
......@@ -7,7 +7,7 @@ import tensorflow as tf
import numpy
import os
def compute_features(input_image, architecture, checkpoint_dir, target_end_points):
def compute_features(input_image, architecture, checkpoint_dir, target_end_points, preprocess_fn=None):
"""
For a given set of end_points, convolve the input image until these points
......@@ -26,15 +26,19 @@ def compute_features(input_image, architecture, checkpoint_dir, target_end_point
end_points: dict
Dictionary containing the end point tensors
preprocess_fn:
Pointer to a preprocess function
"""
input_pl = tf.placeholder('float32', shape=(1, input_image.shape[1],
input_image.shape[2],
input_image.shape[3]))
# TODO: Think on how abstract this normalization operation
_, end_points = architecture(tf.stack([tf.image.per_image_standardization(i) for i in tf.unstack(input_pl)]), mode=tf.estimator.ModeKeys.PREDICT, trainable_variables=None)
if preprocess_fn is None:
_, end_points = architecture(input_pl, mode=tf.estimator.ModeKeys.PREDICT, trainable_variables=None)
else:
_, end_points = architecture(tf.stack([preprocess_fn(i) for i in tf.unstack(input_pl)]), mode=tf.estimator.ModeKeys.PREDICT, trainable_variables=None)
with tf.Session() as sess:
# Restoring the checkpoint for the given architecture
sess.run(tf.global_variables_initializer())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment