Commit cda150df authored by Amir MOHAMMADI

add starting point options to the style transfer script

parent 63993d46
Merge request !79: Add keras-based models, add pixel-wise loss, other improvements
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 @click.option('--style-image-paths',
               cls=ResourceOption,
               required=True,
-              multiple=True,
+              multiple=True,
               entry_point_group='bob.learn.tensorflow.style_images',
               help='List of images that encode the style.')
 @click.option('--architecture',
@@ -95,13 +95,21 @@ logger = logging.getLogger(__name__)
               cls=ResourceOption,
               entry_point_group='bob.learn.tensorflow.preprocess_fn',
               help='Un-preprocess function. Pointer to a function that un-preprocesses the OUTPUT signal')
+@click.option(
+    '--start-from',
+    '-sf',
+    cls=ResourceOption,
+    default="noise",
+    type=click.Choice(["noise", "content", "style"]),
+    help="Start the reconstruction from this image",
+)
 @verbosity_option(cls=ResourceOption)
 def style_transfer(content_image_path, output_path, style_image_paths,
                    architecture, checkpoint_dir,
                    iterations, learning_rate,
                    content_weight, style_weight, denoise_weight, content_end_points,
-                   style_end_points, scopes, pure_noise, preprocess_fn,
-                   un_preprocess_fn, **kwargs):
+                   style_end_points, scopes, pure_noise, preprocess_fn,
+                   un_preprocess_fn, start_from, **kwargs):
     """
     Trains neural style transfer using the approach presented in:
 
@@ -112,7 +120,7 @@ def style_transfer(content_image_path, output_path, style_image_paths,
     If you want to run a style transfer using InceptionV2 as the basis, use the following template.
     Below follows a CONFIG template:
 
     CONFIG.PY
     ```
@@ -159,7 +167,7 @@ def style_transfer(content_image_path, output_path, style_image_paths,
                          "STYLE_2.png"]
     ```
 
     Then run::
 
       $ bob tf style <content-image> <output-image> CONFIG.py
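
With the `--start-from`/`-sf` option introduced by this commit, the reconstruction can be seeded from the content image or the first style image instead of random noise. For example (assuming the option can be given on the command line alongside CONFIG.py, as click normally allows):

  $ bob tf style <content-image> <output-image> CONFIG.py --start-from content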
@@ -178,14 +186,14 @@ def style_transfer(content_image_path, output_path, style_image_paths,
     for path in style_image_paths:
         style_images.append(bob.io.base.load(path))
 
-    output = do_style_transfer(content_image, style_images,
+    output = do_style_transfer(content_image, style_images,
                                architecture, checkpoint_dir, scopes,
                                content_end_points, style_end_points,
                                preprocess_fn=preprocess_fn, un_preprocess_fn=un_preprocess_fn,
                                pure_noise=pure_noise,
                                iterations=iterations, learning_rate=learning_rate,
                                content_weight=content_weight, style_weight=style_weight,
-                               denoise_weight=denoise_weight)
+                               denoise_weight=denoise_weight, start_from=start_from)
 
     os.makedirs(os.path.dirname(output_path), exist_ok=True)
     bob.io.base.save(output, output_path)
@@ -57,10 +57,7 @@ def compute_features(input_image, architecture, checkpoint_dir, target_end_points,
         saver.restore(sess, checkpoint_dir)
         #content_feature = sess.run(end_points[CONTENT_END_POINTS], feed_dict={input_image: content_image})
-        features = []
-        for ep in target_end_points:
-            feature = sess.run(end_points[ep], feed_dict={input_pl: input_image})
-            features.append(feature)
+        features = sess.run([end_points[ep] for ep in target_end_points], feed_dict={input_pl: input_image})
     # Killing the graph
     tf.reset_default_graph()
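
The refactor above replaces the per-end-point loop with a single `sess.run` call: `sess.run` accepts a list of fetches and evaluates them in one graph execution, returning a list in the same order, so the forward pass no longer runs once per end point. A minimal, self-contained sketch of that behavior (assuming TensorFlow 1.x, which this file targets; the placeholder and end points are toy stand-ins):

```
import numpy as np
import tensorflow as tf

tf.reset_default_graph()

# Toy stand-ins for the real placeholder and architecture end points.
input_pl = tf.placeholder(tf.float32, shape=(None, 4))
end_points = {"fc1": input_pl * 2.0, "fc2": input_pl + 1.0}
target_end_points = ["fc1", "fc2"]

with tf.Session() as sess:
    # One graph execution returns all requested features, in order.
    features = sess.run(
        [end_points[ep] for ep in target_end_points],
        feed_dict={input_pl: np.ones((1, 4), dtype="float32")},
    )
```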
@@ -95,7 +92,7 @@ def do_style_transfer(content_image, style_images,
                       content_end_points, style_end_points,
                       preprocess_fn=None, un_preprocess_fn=None, pure_noise=False,
                       iterations=1000, learning_rate=0.1,
-                      content_weight=5., style_weight=500., denoise_weight=500.):
+                      content_weight=5., style_weight=500., denoise_weight=500., start_from="noise"):
     """
     Trains neural style transfer using the approach presented in:
@@ -192,8 +189,16 @@ def do_style_transfer(content_image, style_images,
     tf.set_random_seed(0)
 
     # Random noise
-    noise = tf.Variable(tf.random_normal(shape=content_image.shape),
-                        trainable=True) * 0.256
+    if start_from == "noise":
+        starting_image = tf.random_normal(shape=content_image.shape) * 0.256
+    elif start_from == "content":
+        starting_image = preprocess_fn(content_image)
+    elif start_from == "style":
+        starting_image = preprocess_fn(style_images[0])
+    else:
+        raise ValueError(f"Unknown starting image: {start_from}")
+    noise = tf.Variable(starting_image, dtype="float32", trainable=True)
     _, end_points = architecture(noise,
                                  mode=tf.estimator.ModeKeys.PREDICT,
                                  trainable_variables=[])
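
Whichever image `start_from` selects, it is only the initial value of a trainable `tf.Variable`; the optimizer then updates the image itself, so the option changes where gradient descent starts, not what is optimized. A minimal sketch of that mechanism (assuming TensorFlow 1.x and a toy loss in place of the style-transfer objective):

```
import numpy as np
import tensorflow as tf

tf.reset_default_graph()

# Start from a constant "content" image (toy shape).
starting_image = np.full((1, 8, 8, 3), 0.5, dtype="float32")
image = tf.Variable(starting_image, dtype="float32", trainable=True)

loss = tf.reduce_mean(tf.square(image))  # stand-in for the transfer loss
step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    before = sess.run(image)
    sess.run(step)  # the image itself is the optimized parameter
    after = sess.run(image)

assert not np.allclose(before, after)
```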