Commit 5d085e94 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Added logger

parent 46d14f90
Pipeline #21862 passed with stage
in 21 minutes and 46 seconds
......@@ -22,11 +22,15 @@ from bob.learn.tensorflow.loss import linear_gram_style_loss, content_loss, deno
logger = logging.getLogger(__name__)
def wise_shape(shape):
if len(shape)==2:
return (1, shape[0], shape[1], 1)
else:
return (1, shape[0], shape[1], shape[2])
def normalize4save(img):
return (255 * ((img - numpy.min(img)) / (numpy.max(img)-numpy.min(img)))).astype("uint8")
#@click.argument('style_image_path', required=True)
@click.command(
entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
......@@ -168,26 +172,24 @@ def style_transfer(content_image_path, output_path, style_image_paths,
style_images = []
for path in style_image_paths:
image = bob.io.image.to_matplotlib(bob.io.base.load(path))
style_images.append(numpy.reshape(image, (1, image.shape[0],
image.shape[1],
image.shape[2])))
style_images.append(numpy.reshape(image, wise_shape(image.shape)))
# Reshaping to NxWxHxC
content_image = numpy.reshape(content_image, (1, content_image.shape[0],
content_image.shape[1],
content_image.shape[2]))
content_image = numpy.reshape(content_image, wise_shape(content_image.shape))
# Base content features
logger.info("Computing content features")
content_features = compute_features(content_image, architecture, checkpoint_dir, content_end_points)
# Base style features
# TODO: Enable a set of style images
logger.info("Computing style features")
style_grams = []
for image in style_images:
style_features = compute_features(image, architecture, checkpoint_dir, style_end_points)
style_grams.append(compute_gram(style_features))
# Organizing the trainer
logger.info("Training.....")
with tf.Graph().as_default():
tf.set_random_seed(0)
......@@ -234,7 +236,7 @@ def style_transfer(content_image_path, output_path, style_image_paths,
for i in range(iterations):
_, loss = sess.run([solver, total_loss])
print("Iteration {0}, loss {1}".format(i, loss))
logger.info("Iteration {0}, loss {1}".format(i, loss))
sys.stdout.flush()
# Saving generated image
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment