Skip to content
Snippets Groups Projects
Commit cda150df authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

add starting point options to the style transfer script

parent 63993d46
No related branches found
No related tags found
1 merge request!79Add keras-based models, add pixel-wise loss, other improvements
...@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) ...@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
@click.option('--style-image-paths', @click.option('--style-image-paths',
cls=ResourceOption, cls=ResourceOption,
required=True, required=True,
multiple=True, multiple=True,
entry_point_group='bob.learn.tensorflow.style_images', entry_point_group='bob.learn.tensorflow.style_images',
help='List of images that encodes the style.') help='List of images that encodes the style.')
@click.option('--architecture', @click.option('--architecture',
...@@ -95,13 +95,21 @@ logger = logging.getLogger(__name__) ...@@ -95,13 +95,21 @@ logger = logging.getLogger(__name__)
cls=ResourceOption, cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.preprocess_fn', entry_point_group='bob.learn.tensorflow.preprocess_fn',
help='Un preprocess function. Pointer to a function that preprocess the OUTPUT signal') help='Un preprocess function. Pointer to a function that preprocess the OUTPUT signal')
@click.option(
'--start-from',
'-sf',
cls=ResourceOption,
default="noise",
type=click.Choice(["noise", "content", "style"]),
help="Starts from this image for reconstruction",
)
@verbosity_option(cls=ResourceOption) @verbosity_option(cls=ResourceOption)
def style_transfer(content_image_path, output_path, style_image_paths, def style_transfer(content_image_path, output_path, style_image_paths,
architecture, checkpoint_dir, architecture, checkpoint_dir,
iterations, learning_rate, iterations, learning_rate,
content_weight, style_weight, denoise_weight, content_end_points, content_weight, style_weight, denoise_weight, content_end_points,
style_end_points, scopes, pure_noise, preprocess_fn, style_end_points, scopes, pure_noise, preprocess_fn,
un_preprocess_fn, **kwargs): un_preprocess_fn, start_from, **kwargs):
""" """
Trains neural style transfer using the approach presented in: Trains neural style transfer using the approach presented in:
...@@ -112,7 +120,7 @@ def style_transfer(content_image_path, output_path, style_image_paths, ...@@ -112,7 +120,7 @@ def style_transfer(content_image_path, output_path, style_image_paths,
If you want run a style transfer using InceptionV2 as basis, use the following template If you want run a style transfer using InceptionV2 as basis, use the following template
Below follow a CONFIG template Below follow a CONFIG template
CONFIG.PY CONFIG.PY
``` ```
...@@ -159,7 +167,7 @@ def style_transfer(content_image_path, output_path, style_image_paths, ...@@ -159,7 +167,7 @@ def style_transfer(content_image_path, output_path, style_image_paths,
"STYLE_2.png"] "STYLE_2.png"]
``` ```
Then run:: Then run::
$ bob tf style <content-image> <output-image> CONFIG.py $ bob tf style <content-image> <output-image> CONFIG.py
...@@ -178,14 +186,14 @@ def style_transfer(content_image_path, output_path, style_image_paths, ...@@ -178,14 +186,14 @@ def style_transfer(content_image_path, output_path, style_image_paths,
for path in style_image_paths: for path in style_image_paths:
style_images.append(bob.io.base.load(path)) style_images.append(bob.io.base.load(path))
output = do_style_transfer(content_image, style_images, output = do_style_transfer(content_image, style_images,
architecture, checkpoint_dir, scopes, architecture, checkpoint_dir, scopes,
content_end_points, style_end_points, content_end_points, style_end_points,
preprocess_fn=preprocess_fn, un_preprocess_fn=un_preprocess_fn, preprocess_fn=preprocess_fn, un_preprocess_fn=un_preprocess_fn,
pure_noise=pure_noise, pure_noise=pure_noise,
iterations=iterations, learning_rate=learning_rate, iterations=iterations, learning_rate=learning_rate,
content_weight=content_weight, style_weight=style_weight, content_weight=content_weight, style_weight=style_weight,
denoise_weight=denoise_weight) denoise_weight=denoise_weight, start_from=start_from)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
bob.io.base.save(output, output_path) bob.io.base.save(output, output_path)
...@@ -57,10 +57,7 @@ def compute_features(input_image, architecture, checkpoint_dir, target_end_point ...@@ -57,10 +57,7 @@ def compute_features(input_image, architecture, checkpoint_dir, target_end_point
saver.restore(sess, checkpoint_dir) saver.restore(sess, checkpoint_dir)
#content_feature = sess.run(end_points[CONTENT_END_POINTS], feed_dict={input_image: content_image}) #content_feature = sess.run(end_points[CONTENT_END_POINTS], feed_dict={input_image: content_image})
features = [] features = sess.run([end_points[ep] for ep in target_end_points], feed_dict={input_pl: input_image})
for ep in target_end_points:
feature = sess.run(end_points[ep], feed_dict={input_pl: input_image})
features.append(feature)
# Killing the graph # Killing the graph
tf.reset_default_graph() tf.reset_default_graph()
...@@ -95,7 +92,7 @@ def do_style_transfer(content_image, style_images, ...@@ -95,7 +92,7 @@ def do_style_transfer(content_image, style_images,
content_end_points, style_end_points, content_end_points, style_end_points,
preprocess_fn=None, un_preprocess_fn=None, pure_noise=False, preprocess_fn=None, un_preprocess_fn=None, pure_noise=False,
iterations=1000, learning_rate=0.1, iterations=1000, learning_rate=0.1,
content_weight=5., style_weight=500., denoise_weight=500.): content_weight=5., style_weight=500., denoise_weight=500., start_from="noise"):
""" """
Trains neural style transfer using the approach presented in: Trains neural style transfer using the approach presented in:
...@@ -192,8 +189,16 @@ def do_style_transfer(content_image, style_images, ...@@ -192,8 +189,16 @@ def do_style_transfer(content_image, style_images,
tf.set_random_seed(0) tf.set_random_seed(0)
# Random noise # Random noise
noise = tf.Variable(tf.random_normal(shape=content_image.shape), if start_from == "noise":
trainable=True) * 0.256 starting_image = tf.random_normal(shape=content_image.shape) * 0.256
elif start_from == "content":
starting_image = preprocess_fn(content_image)
elif start_from == "style":
starting_image = preprocess_fn(style_images[0])
else:
raise ValueError(f"Unknown starting image: {start_from}")
noise = tf.Variable(starting_image, dtype="float32", trainable=True)
_, end_points = architecture(noise, _, end_points = architecture(noise,
mode=tf.estimator.ModeKeys.PREDICT, mode=tf.estimator.ModeKeys.PREDICT,
trainable_variables=[]) trainable_variables=[])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment