Commit cda150df authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

add starting point options to the style transfer script

parent 63993d46
...@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) ...@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
@click.option('--style-image-paths', @click.option('--style-image-paths',
cls=ResourceOption, cls=ResourceOption,
required=True, required=True,
multiple=True, multiple=True,
entry_point_group='bob.learn.tensorflow.style_images', entry_point_group='bob.learn.tensorflow.style_images',
help='List of images that encodes the style.') help='List of images that encodes the style.')
@click.option('--architecture', @click.option('--architecture',
...@@ -95,13 +95,21 @@ logger = logging.getLogger(__name__) ...@@ -95,13 +95,21 @@ logger = logging.getLogger(__name__)
cls=ResourceOption, cls=ResourceOption,
entry_point_group='bob.learn.tensorflow.preprocess_fn', entry_point_group='bob.learn.tensorflow.preprocess_fn',
help='Un preprocess function. Pointer to a function that preprocess the OUTPUT signal') help='Un preprocess function. Pointer to a function that preprocess the OUTPUT signal')
@click.option(
'--start-from',
'-sf',
cls=ResourceOption,
default="noise",
type=click.Choice(["noise", "content", "style"]),
help="Starts from this image for reconstruction",
)
@verbosity_option(cls=ResourceOption) @verbosity_option(cls=ResourceOption)
def style_transfer(content_image_path, output_path, style_image_paths, def style_transfer(content_image_path, output_path, style_image_paths,
architecture, checkpoint_dir, architecture, checkpoint_dir,
iterations, learning_rate, iterations, learning_rate,
content_weight, style_weight, denoise_weight, content_end_points, content_weight, style_weight, denoise_weight, content_end_points,
style_end_points, scopes, pure_noise, preprocess_fn, style_end_points, scopes, pure_noise, preprocess_fn,
un_preprocess_fn, **kwargs): un_preprocess_fn, start_from, **kwargs):
""" """
Trains neural style transfer using the approach presented in: Trains neural style transfer using the approach presented in:
...@@ -112,7 +120,7 @@ def style_transfer(content_image_path, output_path, style_image_paths, ...@@ -112,7 +120,7 @@ def style_transfer(content_image_path, output_path, style_image_paths,
If you want run a style transfer using InceptionV2 as basis, use the following template If you want run a style transfer using InceptionV2 as basis, use the following template
Below follow a CONFIG template Below follow a CONFIG template
CONFIG.PY CONFIG.PY
``` ```
...@@ -159,7 +167,7 @@ def style_transfer(content_image_path, output_path, style_image_paths, ...@@ -159,7 +167,7 @@ def style_transfer(content_image_path, output_path, style_image_paths,
"STYLE_2.png"] "STYLE_2.png"]
``` ```
Then run:: Then run::
$ bob tf style <content-image> <output-image> CONFIG.py $ bob tf style <content-image> <output-image> CONFIG.py
...@@ -178,14 +186,14 @@ def style_transfer(content_image_path, output_path, style_image_paths, ...@@ -178,14 +186,14 @@ def style_transfer(content_image_path, output_path, style_image_paths,
for path in style_image_paths: for path in style_image_paths:
style_images.append(bob.io.base.load(path)) style_images.append(bob.io.base.load(path))
output = do_style_transfer(content_image, style_images, output = do_style_transfer(content_image, style_images,
architecture, checkpoint_dir, scopes, architecture, checkpoint_dir, scopes,
content_end_points, style_end_points, content_end_points, style_end_points,
preprocess_fn=preprocess_fn, un_preprocess_fn=un_preprocess_fn, preprocess_fn=preprocess_fn, un_preprocess_fn=un_preprocess_fn,
pure_noise=pure_noise, pure_noise=pure_noise,
iterations=iterations, learning_rate=learning_rate, iterations=iterations, learning_rate=learning_rate,
content_weight=content_weight, style_weight=style_weight, content_weight=content_weight, style_weight=style_weight,
denoise_weight=denoise_weight) denoise_weight=denoise_weight, start_from=start_from)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
bob.io.base.save(output, output_path) bob.io.base.save(output, output_path)
...@@ -57,10 +57,7 @@ def compute_features(input_image, architecture, checkpoint_dir, target_end_point ...@@ -57,10 +57,7 @@ def compute_features(input_image, architecture, checkpoint_dir, target_end_point
saver.restore(sess, checkpoint_dir) saver.restore(sess, checkpoint_dir)
#content_feature = sess.run(end_points[CONTENT_END_POINTS], feed_dict={input_image: content_image}) #content_feature = sess.run(end_points[CONTENT_END_POINTS], feed_dict={input_image: content_image})
features = [] features = sess.run([end_points[ep] for ep in target_end_points], feed_dict={input_pl: input_image})
for ep in target_end_points:
feature = sess.run(end_points[ep], feed_dict={input_pl: input_image})
features.append(feature)
# Killing the graph # Killing the graph
tf.reset_default_graph() tf.reset_default_graph()
...@@ -95,7 +92,7 @@ def do_style_transfer(content_image, style_images, ...@@ -95,7 +92,7 @@ def do_style_transfer(content_image, style_images,
content_end_points, style_end_points, content_end_points, style_end_points,
preprocess_fn=None, un_preprocess_fn=None, pure_noise=False, preprocess_fn=None, un_preprocess_fn=None, pure_noise=False,
iterations=1000, learning_rate=0.1, iterations=1000, learning_rate=0.1,
content_weight=5., style_weight=500., denoise_weight=500.): content_weight=5., style_weight=500., denoise_weight=500., start_from="noise"):
""" """
Trains neural style transfer using the approach presented in: Trains neural style transfer using the approach presented in:
...@@ -192,8 +189,16 @@ def do_style_transfer(content_image, style_images, ...@@ -192,8 +189,16 @@ def do_style_transfer(content_image, style_images,
tf.set_random_seed(0) tf.set_random_seed(0)
# Random noise # Random noise
noise = tf.Variable(tf.random_normal(shape=content_image.shape), if start_from == "noise":
trainable=True) * 0.256 starting_image = tf.random_normal(shape=content_image.shape) * 0.256
elif start_from == "content":
starting_image = preprocess_fn(content_image)
elif start_from == "style":
starting_image = preprocess_fn(style_images[0])
else:
raise ValueError(f"Unknown starting image: {start_from}")
noise = tf.Variable(starting_image, dtype="float32", trainable=True)
_, end_points = architecture(noise, _, end_points = architecture(noise,
mode=tf.estimator.ModeKeys.PREDICT, mode=tf.estimator.ModeKeys.PREDICT,
trainable_variables=[]) trainable_variables=[])
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment