Commit 60cf788d authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Improved checkpoint discovery

parent 40a926d4
Pipeline #21926 failed with stage
in 19 minutes and 10 seconds
......@@ -15,7 +15,7 @@ import numpy
import bob.ip.base
import bob.ip.color
import sys
import os
from bob.learn.tensorflow.style_transfer import compute_features, compute_gram
from bob.learn.tensorflow.loss import linear_gram_style_loss, content_loss, denoising_loss
......@@ -228,8 +228,7 @@ def style_transfer(content_image_path, output_path, style_image_paths,
solver = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)
tf.contrib.framework.init_from_checkpoint(tf.train.latest_checkpoint(checkpoint_dir),
scopes)
tf.contrib.framework.init_from_checkpoint(tf.train.latest_checkpoint(checkpoint_dir) if os.path.isdir(checkpoint_dir) else checkpoint_dir, scopes)
# Training
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
......@@ -245,11 +244,19 @@ def style_transfer(content_image_path, output_path, style_image_paths,
normalized_style_image = normalize4save(raw_style_image)
if pure_noise:
if normalized_style_image.shape[0] == 1:
bob.io.base.save(normalized_style_image[0, :, :], output_path)
else:
bob.io.base.save(normalized_style_image, output_path)
else:
# Original output
if normalized_style_image.shape[0] == 1:
normalized_style_image_yuv = bob.ip.color.rgb_to_yuv(bob.ip.color.gray_to_rgb(normalized_style_image[0,:,:]))
# Loading the content image and clipping from 0-255 in case is in another scale
scaled_content_image = normalize4save(bob.io.base.load(content_image_path).astype("float32")).astype("float64")
content_image_yuv = bob.ip.color.rgb_to_yuv(bob.ip.color.gray_to_rgb(scaled_content_image))
else:
normalized_style_image_yuv = bob.ip.color.rgb_to_yuv(bob.ip.color.gray_to_rgb(bob.ip.color.rgb_to_gray(normalized_style_image)))
content_image_yuv = bob.ip.color.rgb_to_yuv(bob.io.base.load(content_image_path))
output_image = numpy.zeros(shape=content_image_yuv.shape, dtype="uint8")
......
......@@ -5,6 +5,7 @@
import tensorflow as tf
import numpy
import os
def compute_features(input_image, architecture, checkpoint_dir, target_end_points):
"""
......@@ -26,9 +27,11 @@ def compute_features(input_image, architecture, checkpoint_dir, target_end_point
Dictionary containing the end point tensors
"""
input_pl = tf.placeholder('float32', shape=(1, input_image.shape[1],
input_image.shape[2],
input_image.shape[3]))
# TODO: Think on how abstract this normalization operation
_, end_points = architecture(tf.stack([tf.image.per_image_standardization(i) for i in tf.unstack(input_pl)]), mode=tf.estimator.ModeKeys.PREDICT, trainable_variables=None)
......@@ -36,7 +39,11 @@ def compute_features(input_image, architecture, checkpoint_dir, target_end_point
# Restoring the checkpoint for the given architecture
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
if os.path.isdir(checkpoint_dir):
saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
else:
saver.restore(sess, checkpoint_dir)
#content_feature = sess.run(end_points[CONTENT_END_POINTS], feed_dict={input_image: content_image})
features = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment