diff --git a/bob/learn/tensorflow/configs/style_transfer/inception_v1_example.py b/bob/learn/tensorflow/configs/style_transfer/inception_v1_example.py new file mode 100644 index 0000000000000000000000000000000000000000..00bfa1e0628534802d47bba19ad27b20b04ef236 --- /dev/null +++ b/bob/learn/tensorflow/configs/style_transfer/inception_v1_example.py @@ -0,0 +1,27 @@ +""" +Example using inception resnet v1 + +""" + +import tensorflow as tf + +# -- architecture +from bob.learn.tensorflow.network import inception_resnet_v1_batch_norm +architecture = inception_resnet_v1_batch_norm + +# --checkpoint-dir +from bob.extension import rc +checkpoint_dir = rc['bob.bio.face_ongoing.inception-v1_batchnorm_rgb'] + +# --style-end-points and -- content-end-points +style_end_points = ["Conv2d_1a_3x3", "Conv2d_2b_3x3"] +content_end_points = ["Block8"] + +scopes = {"InceptionResnetV1/":"InceptionResnetV1/"} + +# --style-image-paths +style_image_paths = ["vincent_van_gogh.jpg", + "vincent_van_gogh2.jpg"] + +# --preprocess-fn +preprocess_fn = tf.image.per_image_standardization diff --git a/bob/learn/tensorflow/configs/style_transfer/inception_v2_example.py b/bob/learn/tensorflow/configs/style_transfer/inception_v2_example.py new file mode 100644 index 0000000000000000000000000000000000000000..54eb3f2931f65e78fb6a1e201b65d119f172d0be --- /dev/null +++ b/bob/learn/tensorflow/configs/style_transfer/inception_v2_example.py @@ -0,0 +1,26 @@ +""" +Example using inception resnet v2 +""" + +import tensorflow as tf + +# -- architecture +from bob.learn.tensorflow.network import inception_resnet_v2_batch_norm +architecture = inception_resnet_v2_batch_norm + +# --checkpoint-dir +from bob.extension import rc +checkpoint_dir = rc['bob.bio.face_ongoing.inception-v2_batchnorm_rgb'] + +# --style-end-points and -- content-end-points +style_end_points = ["Conv2d_1a_3x3", "Conv2d_2b_3x3"] +content_end_points = ["Block8"] + +scopes = {"InceptionResnetV2/":"InceptionResnetV2/"} + +# --style-image-paths +style_image_paths = ["vincent_van_gogh.jpg", + "vincent_van_gogh2.jpg"] + +# --preprocess-fn +preprocess_fn = tf.image.per_image_standardization diff --git a/bob/learn/tensorflow/configs/style_transfer/vgg19_example.py b/bob/learn/tensorflow/configs/style_transfer/vgg19_example.py new file mode 100644 index 0000000000000000000000000000000000000000..e2f2434c418cfddd34b6f14636bdd3dd862d3544 --- /dev/null +++ b/bob/learn/tensorflow/configs/style_transfer/vgg19_example.py @@ -0,0 +1,44 @@ +""" +Example using VGG19 +""" + +from bob.learn.tensorflow.network import vgg_19 +# --architecture +architecture = vgg_19 + + +import numpy + +# -- checkpoint-dir +# YOU CAN DOWNLOAD THE CHECKPOINTS FROM HERE +# https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models +checkpoint_dir = "[VGG-19-CHECKPOINT]" + +# --style-end-points and -- content-end-points +content_end_points = ['vgg_19/conv4/conv4_2', 'vgg_19/conv5/conv5_2'] +style_end_points = ['vgg_19/conv1/conv1_2', + 'vgg_19/conv2/conv2_1', + 'vgg_19/conv3/conv3_1', + 'vgg_19/conv4/conv4_1', + 'vgg_19/conv5/conv5_1' + ] + + +scopes = {"vgg_19/":"vgg_19/"} + +style_image_paths = ["vincent_van_gogh.jpg", + "vincent_van_gogh2.jpg"] + + +# --preprocess-fn and --un-preprocess-fn +# Taken from VGG19 +def mean_norm(tensor): + return tensor - numpy.array([ 123.68 , 116.779, 103.939]) + +def un_mean_norm(tensor): + return tensor + numpy.array([ 123.68 , 116.779, 103.939]) + +preprocess_fn = mean_norm + +un_preprocess_fn = un_mean_norm + diff --git 
a/bob/learn/tensorflow/configs/style_transfer/vincent_van_gogh.jpg b/bob/learn/tensorflow/configs/style_transfer/vincent_van_gogh.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c17534eca0469abfa043146da04f82eab014ba99
Binary files /dev/null and b/bob/learn/tensorflow/configs/style_transfer/vincent_van_gogh.jpg differ
diff --git a/bob/learn/tensorflow/configs/style_transfer/vincent_van_gogh2.jpg b/bob/learn/tensorflow/configs/style_transfer/vincent_van_gogh2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9a74d8588e73b044c3f8a51e8bb3bcf1e950174a
Binary files /dev/null and b/bob/learn/tensorflow/configs/style_transfer/vincent_van_gogh2.jpg differ
diff --git a/bob/learn/tensorflow/loss/StyleLoss.py b/bob/learn/tensorflow/loss/StyleLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc132d4cbac291d7504eb6e2d990238ba9b78ee1
--- /dev/null
+++ b/bob/learn/tensorflow/loss/StyleLoss.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+
+import logging
+import tensorflow as tf
+from functools import reduce  # reduce is not a builtin on Python 3
+logger = logging.getLogger("bob.learn.tensorflow")
+
+
+def content_loss(noises, content_features):
+    """
+
+    Implements the content loss from:
+
+    Gatys, Leon A., Alexander S. Ecker, and Matthias Bethge. "A neural algorithm of artistic style." arXiv preprint arXiv:1508.06576 (2015).
+
+    For a given noise signal :math:`n`, a content image :math:`c` and a DCNN :math:`\phi` evaluated up to the layers :math:`l \in \mathcal{L}`, the content loss is defined as:
+
+    :math:`L(n,c) = \sum_{l \in \mathcal{L}} \|\phi^l(n) - \phi^l(c)\|_2^2`
+
+
+    Parameters
+    ----------
+
+    noises: list
+        A list of tf.Tensor with the noise signal convolved up to each content end point
+
+    content_features: list
+        A list of numpy.array with the content image convolved up to the same end points
+
+    """
+
+    content_losses = []
+    for n,c in zip(noises, content_features):
+        content_losses.append((2 * tf.nn.l2_loss(n - c) / c.size))
+    return reduce(tf.add, content_losses)
+
+
+def linear_gram_style_loss(noises, gram_style_features):
+    """
+
+    Implements the style loss from:
+
+    Gatys, Leon A., Alexander S. Ecker, and Matthias Bethge. "A neural algorithm of artistic style." arXiv preprint arXiv:1508.06576 (2015).
+
+    For a given noise signal :math:`n`, a style image :math:`s` and a DCNN :math:`\phi` evaluated up to the layers :math:`l \in \mathcal{L}`, the style loss compares the :math:`N \times M` Gram matrices of the activations and is defined as:
+
+    :math:`L(n,s) = \sum_{l \in \mathcal{L}}\frac{\|\phi^l(n)^T\phi^l(n) - \phi^l(s)^T\phi^l(s)\|_2^2}{N M}`
+
+
+    Parameters
+    ----------
+
+    noises: list
+        A list of tf.Tensor with the Gram matrices of the noise signal at each style end point
+
+    gram_style_features: list
+        A list of numpy.array with the Gram matrices of the style image at the same end points
+
+    """
+
+    style_losses = []
+    for n,s in zip(noises, gram_style_features):
+        style_losses.append((2 * tf.nn.l2_loss(n - s)) / s.size)
+
+    return reduce(tf.add, style_losses)
+
+
+
+def denoising_loss(noise):
+    """
+    Computes the total variation (denoising) loss of the generated image, as used in:
+
+    Gatys, Leon A., Alexander S. Ecker, and Matthias Bethge. "A neural algorithm of artistic style." arXiv preprint arXiv:1508.06576 (2015). 
+ + Parameters + ---------- + + noise: tf.Tensor + Input noise + + """ + def _tensor_size(tensor): + from operator import mul + return reduce(mul, (d.value for d in tensor.get_shape()), 1) + + shape = noise.get_shape().as_list() + + noise_y_size = _tensor_size(noise[:,1:,:,:]) + noise_x_size = _tensor_size(noise[:,:,1:,:]) + denoise_loss = 2 * ( (tf.nn.l2_loss(noise[:,1:,:,:] - noise[:,:shape[1]-1,:,:]) / noise_y_size) + + (tf.nn.l2_loss(noise[:,:,1:,:] - noise[:,:,:shape[2]-1,:]) / noise_x_size)) + + return denoise_loss + diff --git a/bob/learn/tensorflow/loss/__init__.py b/bob/learn/tensorflow/loss/__init__.py index 379b180f784969518958e2864f27c001ff21f17c..17947ea14336862a00238021238975d58ec84895 100644 --- a/bob/learn/tensorflow/loss/__init__.py +++ b/bob/learn/tensorflow/loss/__init__.py @@ -1,6 +1,7 @@ from .BaseLoss import mean_cross_entropy_loss, mean_cross_entropy_center_loss from .ContrastiveLoss import contrastive_loss from .TripletLoss import triplet_loss, triplet_average_loss, triplet_fisher_loss +from .StyleLoss import linear_gram_style_loss, content_loss, denoising_loss # gets sphinx autodoc done right - don't remove it diff --git a/bob/learn/tensorflow/network/InceptionResnetV1.py b/bob/learn/tensorflow/network/InceptionResnetV1.py index 5e7d37609587b1be90cd2bd890b0eaa150b88951..6d84580b89755cd3769b639fbae86c916aad60c9 100644 --- a/bob/learn/tensorflow/network/InceptionResnetV1.py +++ b/bob/learn/tensorflow/network/InceptionResnetV1.py @@ -303,7 +303,7 @@ def inception_resnet_v1_batch_norm(inputs, # force in-place updates of mean and variance estimates 'updates_collections': None, # Moving averages ends up in the trainable variables collection - 'variables_collections': [tf.GraphKeys.TRAINABLE_VARIABLES], + 'variables_collections': [tf.GraphKeys.TRAINABLE_VARIABLES if mode==tf.estimator.ModeKeys.TRAIN else None], } with slim.arg_scope( @@ -363,7 +363,7 @@ def inception_resnet_v1(inputs, with tf.variable_scope(scope, 'InceptionResnetV1', [inputs], reuse=reuse): with slim.arg_scope( - [slim.batch_norm, slim.dropout], + [slim.dropout], is_training=(mode == tf.estimator.ModeKeys.TRAIN)): with slim.arg_scope( @@ -373,37 +373,53 @@ def inception_resnet_v1(inputs, # 149 x 149 x 32 name = "Conv2d_1a_3x3" - trainable = is_trainable(name, trainable_variables) - net = slim.conv2d( - inputs, - 32, - 3, - stride=2, - padding='VALID', - scope=name, - trainable=trainable, - reuse=reuse) - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + + net = slim.conv2d( + inputs, + 32, + 3, + stride=2, + padding='VALID', + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net # 147 x 147 x 32 name = "Conv2d_2a_3x3" - trainable = is_trainable(name, trainable_variables) - net = slim.conv2d( - net, - 32, - 3, - padding='VALID', - scope=name, - trainable=trainable, - reuse=reuse) - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + + net = slim.conv2d( + net, + 32, + 3, + padding='VALID', + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net # 147 x 147 x 64 name = "Conv2d_2b_3x3" - trainable = is_trainable(name, trainable_variables) - net = slim.conv2d( - net, 64, 3, scope=name, trainable=trainable, reuse=reuse) - end_points[name] = net + 
trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + + net = slim.conv2d( + net, 64, 3, scope=name, trainable=trainable, reuse=reuse) + end_points[name] = net + # 73 x 73 x 64 net = slim.max_pool2d( net, 3, stride=2, padding='VALID', scope='MaxPool_3a_3x3') @@ -411,110 +427,148 @@ def inception_resnet_v1(inputs, # 73 x 73 x 80 name = "Conv2d_3b_1x1" - trainable = is_trainable(name, trainable_variables) - net = slim.conv2d( - net, - 80, - 1, - padding='VALID', - scope=name, - trainable=trainable, - reuse=reuse) - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + net = slim.conv2d( + net, + 80, + 1, + padding='VALID', + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net # 71 x 71 x 192 name = "Conv2d_4a_3x3" - trainable = is_trainable(name, trainable_variables) - net = slim.conv2d( - net, - 192, - 3, - padding='VALID', - scope=name, - trainable=trainable, - reuse=reuse) - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + net = slim.conv2d( + net, + 192, + 3, + padding='VALID', + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net # 35 x 35 x 256 name = "Conv2d_4b_3x3" - trainable = is_trainable(name, trainable_variables) - net = slim.conv2d( - net, - 256, - 3, - stride=2, - padding='VALID', - scope=name, - trainable=trainable, - reuse=reuse) - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + net = slim.conv2d( + net, + 256, + 3, + stride=2, + padding='VALID', + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net # 5 x Inception-resnet-A name = "block35" - trainable = is_trainable(name, trainable_variables) - net = slim.repeat( - net, - 5, - block35, - scale=0.17, - trainable_variables=trainable, - reuse=reuse) - end_points[name] = net - - # Reduction-A - name = "Mixed_6a" - trainable = is_trainable(name, trainable_variables) - with tf.variable_scope(name): - net = reduction_a( + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + net = slim.repeat( net, - 192, - 192, - 256, - 384, + 5, + block35, + scale=0.17, trainable_variables=trainable, reuse=reuse) - end_points[name] = net + end_points[name] = net + + # Reduction-A + name = "Mixed_6a" + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + with tf.variable_scope(name): + net = reduction_a( + net, + 192, + 192, + 256, + 384, + trainable_variables=trainable, + reuse=reuse) + end_points[name] = net # 10 x Inception-Resnet-B name = "block17" - trainable = is_trainable(name, trainable_variables) - net = slim.repeat( - net, - 10, - block17, - scale=0.10, - trainable_variables=trainable, - reuse=reuse) - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with 
slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + net = slim.repeat( + net, + 10, + block17, + scale=0.10, + trainable_variables=trainable, + reuse=reuse) + end_points[name] = net # Reduction-B name = "Mixed_7a" - trainable = is_trainable(name, trainable_variables) - with tf.variable_scope(name): - net = reduction_b( - net, trainable_variables=trainable, reuse=reuse) - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + + with tf.variable_scope(name): + net = reduction_b( + net, trainable_variables=trainable, reuse=reuse) + end_points[name] = net # 5 x Inception-Resnet-C name = "block8" - trainable = is_trainable(name, trainable_variables) - net = slim.repeat( - net, - 5, - block8, - scale=0.20, - trainable_variables=trainable, - reuse=reuse) - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + net = slim.repeat( + net, + 5, + block8, + scale=0.20, + trainable_variables=trainable, + reuse=reuse) + end_points[name] = net name = "Mixed_8b" - trainable = is_trainable(name, trainable_variables) - net = block8( - net, - activation_fn=None, - trainable_variables=trainable, - reuse=reuse) - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + net = block8( + net, + activation_fn=None, + trainable_variables=trainable, + reuse=reuse) + end_points[name] = net + with tf.variable_scope('Logits'): end_points['PrePool'] = net @@ -535,13 +589,18 @@ def inception_resnet_v1(inputs, end_points['PreLogitsFlatten'] = net name = "Bottleneck" - trainable = is_trainable(name, trainable_variables) - net = slim.fully_connected( - net, - bottleneck_layer_size, - activation_fn=None, - scope=name, - reuse=reuse, - trainable=trainable) + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + net = slim.fully_connected( + net, + bottleneck_layer_size, + activation_fn=None, + scope=name, + reuse=reuse, + trainable=trainable) + end_points[name] = net return net, end_points diff --git a/bob/learn/tensorflow/network/InceptionResnetV2.py b/bob/learn/tensorflow/network/InceptionResnetV2.py index fcdb4c13768055f0e4c0b6a33b528259a8734ead..8b3584794a9ac62d497a5a086673d2901f744624 100644 --- a/bob/learn/tensorflow/network/InceptionResnetV2.py +++ b/bob/learn/tensorflow/network/InceptionResnetV2.py @@ -249,8 +249,10 @@ def inception_resnet_v2_batch_norm(inputs, # force in-place updates of mean and variance estimates 'updates_collections': None, # Moving averages ends up in the trainable variables collection - 'variables_collections': [tf.GraphKeys.TRAINABLE_VARIABLES], + 'variables_collections': [tf.GraphKeys.TRAINABLE_VARIABLES if mode==tf.estimator.ModeKeys.TRAIN else None], } + + weight_decay = 5e-5 with slim.arg_scope( [slim.conv2d, slim.fully_connected], @@ -305,8 +307,7 @@ def inception_resnet_v2(inputs, end_points = {} with tf.variable_scope(scope, 'InceptionResnetV2', [inputs], reuse=reuse): - with slim.arg_scope( - [slim.batch_norm, slim.dropout], + 
with slim.arg_scope([slim.dropout], is_training=(mode == tf.estimator.ModeKeys.TRAIN)): with slim.arg_scope( @@ -315,37 +316,52 @@ def inception_resnet_v2(inputs, padding='SAME'): # 149 x 149 x 32 name = "Conv2d_1a_3x3" - trainable = is_trainable(name, trainable_variables) - net = slim.conv2d( - inputs, - 32, - 3, - stride=2, - padding='VALID', - scope=name, - trainable=trainable, - reuse=reuse) - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + + net = slim.conv2d( + inputs, + 32, + 3, + stride=2, + padding='VALID', + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net # 147 x 147 x 32 name = "Conv2d_2a_3x3" - trainable = is_trainable(name, trainable_variables) - net = slim.conv2d( - net, - 32, - 3, - padding='VALID', - scope=name, - trainable=trainable, - reuse=reuse) - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + + net = slim.conv2d( + net, + 32, + 3, + padding='VALID', + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net # 147 x 147 x 64 name = "Conv2d_2b_3x3" - trainable = is_trainable(name, trainable_variables) - net = slim.conv2d( - net, 64, 3, scope=name, trainable=trainable, reuse=reuse) - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + + net = slim.conv2d( + net, 64, 3, scope=name, trainable=trainable, reuse=reuse) + end_points[name] = net # 73 x 73 x 64 net = slim.max_pool2d( @@ -354,29 +370,39 @@ def inception_resnet_v2(inputs, # 73 x 73 x 80 name = "Conv2d_3b_1x1" - trainable = is_trainable(name, trainable_variables) - net = slim.conv2d( - net, - 80, - 1, - padding='VALID', - scope=name, - trainable=trainable, - reuse=reuse) - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + + net = slim.conv2d( + net, + 80, + 1, + padding='VALID', + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net # 71 x 71 x 192 name = "Conv2d_4a_3x3" - trainable = is_trainable(name, trainable_variables) - net = slim.conv2d( - net, - 192, - 3, - padding='VALID', - scope=name, - trainable=trainable, - reuse=reuse) - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + + net = slim.conv2d( + net, + 192, + 3, + padding='VALID', + scope=name, + trainable=trainable, + reuse=reuse) + end_points[name] = net # 35 x 35 x 192 net = slim.max_pool2d( @@ -385,237 +411,274 @@ def inception_resnet_v2(inputs, # 35 x 35 x 320 name = "Mixed_5b" - trainable = is_trainable(name, trainable_variables) - with tf.variable_scope(name): - with tf.variable_scope('Branch_0'): - tower_conv = slim.conv2d( - net, - 96, - 1, - scope='Conv2d_1x1', - trainable=trainable, - reuse=reuse) - with tf.variable_scope('Branch_1'): - tower_conv1_0 = slim.conv2d( - net, - 48, - 1, - scope='Conv2d_0a_1x1', - trainable=trainable, - reuse=reuse) - tower_conv1_1 = slim.conv2d( - 
tower_conv1_0, - 64, - 5, - scope='Conv2d_0b_5x5', - trainable=trainable, - reuse=reuse) - with tf.variable_scope('Branch_2'): - tower_conv2_0 = slim.conv2d( - net, - 64, - 1, - scope='Conv2d_0a_1x1', - trainable=trainable, - reuse=reuse) - tower_conv2_1 = slim.conv2d( - tower_conv2_0, - 96, - 3, - scope='Conv2d_0b_3x3', - trainable=trainable, - reuse=reuse) - tower_conv2_2 = slim.conv2d( - tower_conv2_1, - 96, - 3, - scope='Conv2d_0c_3x3', - trainable=trainable, - reuse=reuse) - with tf.variable_scope('Branch_3'): - tower_pool = slim.avg_pool2d( - net, - 3, - stride=1, - padding='SAME', - scope='AvgPool_0a_3x3') - tower_pool_1 = slim.conv2d( - tower_pool, - 64, - 1, - scope='Conv2d_0b_1x1', - trainable=trainable, - reuse=reuse) - net = tf.concat([ - tower_conv, tower_conv1_1, tower_conv2_2, tower_pool_1 - ], 3) - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + + with tf.variable_scope(name): + with tf.variable_scope('Branch_0'): + tower_conv = slim.conv2d( + net, + 96, + 1, + scope='Conv2d_1x1', + trainable=trainable, + reuse=reuse) + with tf.variable_scope('Branch_1'): + tower_conv1_0 = slim.conv2d( + net, + 48, + 1, + scope='Conv2d_0a_1x1', + trainable=trainable, + reuse=reuse) + tower_conv1_1 = slim.conv2d( + tower_conv1_0, + 64, + 5, + scope='Conv2d_0b_5x5', + trainable=trainable, + reuse=reuse) + with tf.variable_scope('Branch_2'): + tower_conv2_0 = slim.conv2d( + net, + 64, + 1, + scope='Conv2d_0a_1x1', + trainable=trainable, + reuse=reuse) + tower_conv2_1 = slim.conv2d( + tower_conv2_0, + 96, + 3, + scope='Conv2d_0b_3x3', + trainable=trainable, + reuse=reuse) + tower_conv2_2 = slim.conv2d( + tower_conv2_1, + 96, + 3, + scope='Conv2d_0c_3x3', + trainable=trainable, + reuse=reuse) + with tf.variable_scope('Branch_3'): + tower_pool = slim.avg_pool2d( + net, + 3, + stride=1, + padding='SAME', + scope='AvgPool_0a_3x3') + tower_pool_1 = slim.conv2d( + tower_pool, + 64, + 1, + scope='Conv2d_0b_1x1', + trainable=trainable, + reuse=reuse) + net = tf.concat([ + tower_conv, tower_conv1_1, tower_conv2_2, tower_pool_1 + ], 3) + end_points[name] = net # BLOCK 35 name = "Block35" - trainable = is_trainable(name, trainable_variables) - net = slim.repeat( - net, - 10, - block35, - scale=0.17, - trainable_variables=trainable, - reuse=reuse) + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + + net = slim.repeat( + net, + 10, + block35, + scale=0.17, + trainable_variables=trainable, + reuse=reuse) + end_points[name] = net # 17 x 17 x 1024 name = "Mixed_6a" - trainable = is_trainable(name, trainable_variables) - with tf.variable_scope(name): - with tf.variable_scope('Branch_0'): - tower_conv = slim.conv2d( - net, - 384, - 3, - stride=2, - padding='VALID', - scope='Conv2d_1a_3x3', - trainable=trainable, - reuse=reuse) - with tf.variable_scope('Branch_1'): - tower_conv1_0 = slim.conv2d( - net, - 256, - 1, - scope='Conv2d_0a_1x1', - trainable=trainable, - reuse=reuse) - tower_conv1_1 = slim.conv2d( - tower_conv1_0, - 256, - 3, - scope='Conv2d_0b_3x3', - trainable=trainable, - reuse=reuse) - tower_conv1_2 = slim.conv2d( - tower_conv1_1, - 384, - 3, - stride=2, - padding='VALID', - scope='Conv2d_1a_3x3', - trainable=trainable, - reuse=reuse) - with tf.variable_scope('Branch_2'): - tower_pool = 
slim.max_pool2d( - net, - 3, - stride=2, - padding='VALID', - scope='MaxPool_1a_3x3') - net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3) - - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + + with tf.variable_scope(name): + with tf.variable_scope('Branch_0'): + tower_conv = slim.conv2d( + net, + 384, + 3, + stride=2, + padding='VALID', + scope='Conv2d_1a_3x3', + trainable=trainable, + reuse=reuse) + with tf.variable_scope('Branch_1'): + tower_conv1_0 = slim.conv2d( + net, + 256, + 1, + scope='Conv2d_0a_1x1', + trainable=trainable, + reuse=reuse) + tower_conv1_1 = slim.conv2d( + tower_conv1_0, + 256, + 3, + scope='Conv2d_0b_3x3', + trainable=trainable, + reuse=reuse) + tower_conv1_2 = slim.conv2d( + tower_conv1_1, + 384, + 3, + stride=2, + padding='VALID', + scope='Conv2d_1a_3x3', + trainable=trainable, + reuse=reuse) + with tf.variable_scope('Branch_2'): + tower_pool = slim.max_pool2d( + net, + 3, + stride=2, + padding='VALID', + scope='MaxPool_1a_3x3') + net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3) + end_points[name] = net # BLOCK 17 name = "Block17" - trainable = is_trainable(name, trainable_variables) - net = slim.repeat( - net, - 20, - block17, - scale=0.10, - trainable_variables=trainable, - reuse=reuse) + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + + net = slim.repeat( + net, + 20, + block17, + scale=0.10, + trainable_variables=trainable, + reuse=reuse) + end_points[name] = net name = "Mixed_7a" - trainable = is_trainable(name, trainable_variables) - with tf.variable_scope(name): - with tf.variable_scope('Branch_0'): - tower_conv = slim.conv2d( - net, - 256, - 1, - scope='Conv2d_0a_1x1', - trainable=trainable, - reuse=reuse) - tower_conv_1 = slim.conv2d( - tower_conv, - 384, - 3, - stride=2, - padding='VALID', - scope='Conv2d_1a_3x3', - trainable=trainable, - reuse=reuse) - with tf.variable_scope('Branch_1'): - tower_conv1 = slim.conv2d( - net, - 256, - 1, - scope='Conv2d_0a_1x1', - trainable=trainable, - reuse=reuse) - tower_conv1_1 = slim.conv2d( - tower_conv1, - 288, - 3, - stride=2, - padding='VALID', - scope='Conv2d_1a_3x3', - trainable=trainable, - reuse=reuse) - with tf.variable_scope('Branch_2'): - tower_conv2 = slim.conv2d( - net, - 256, - 1, - scope='Conv2d_0a_1x1', - trainable=trainable, - reuse=reuse) - tower_conv2_1 = slim.conv2d( - tower_conv2, - 288, - 3, - scope='Conv2d_0b_3x3', - trainable=trainable, - reuse=reuse) - tower_conv2_2 = slim.conv2d( - tower_conv2_1, - 320, - 3, - stride=2, - padding='VALID', - scope='Conv2d_1a_3x3', - trainable=trainable, - reuse=reuse) - with tf.variable_scope('Branch_3'): - tower_pool = slim.max_pool2d( - net, - 3, - stride=2, - padding='VALID', - scope='MaxPool_1a_3x3') - net = tf.concat([ - tower_conv_1, tower_conv1_1, tower_conv2_2, tower_pool - ], 3) - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + + with tf.variable_scope(name): + with tf.variable_scope('Branch_0'): + tower_conv = slim.conv2d( + net, + 256, + 1, + scope='Conv2d_0a_1x1', + trainable=trainable, + reuse=reuse) + tower_conv_1 = slim.conv2d( + tower_conv, + 384, + 3, + stride=2, 
+ padding='VALID', + scope='Conv2d_1a_3x3', + trainable=trainable, + reuse=reuse) + with tf.variable_scope('Branch_1'): + tower_conv1 = slim.conv2d( + net, + 256, + 1, + scope='Conv2d_0a_1x1', + trainable=trainable, + reuse=reuse) + tower_conv1_1 = slim.conv2d( + tower_conv1, + 288, + 3, + stride=2, + padding='VALID', + scope='Conv2d_1a_3x3', + trainable=trainable, + reuse=reuse) + with tf.variable_scope('Branch_2'): + tower_conv2 = slim.conv2d( + net, + 256, + 1, + scope='Conv2d_0a_1x1', + trainable=trainable, + reuse=reuse) + tower_conv2_1 = slim.conv2d( + tower_conv2, + 288, + 3, + scope='Conv2d_0b_3x3', + trainable=trainable, + reuse=reuse) + tower_conv2_2 = slim.conv2d( + tower_conv2_1, + 320, + 3, + stride=2, + padding='VALID', + scope='Conv2d_1a_3x3', + trainable=trainable, + reuse=reuse) + with tf.variable_scope('Branch_3'): + tower_pool = slim.max_pool2d( + net, + 3, + stride=2, + padding='VALID', + scope='MaxPool_1a_3x3') + net = tf.concat([ + tower_conv_1, tower_conv1_1, tower_conv2_2, tower_pool + ], 3) + end_points[name] = net # Block 8 name = "Block8" - trainable = is_trainable(name, trainable_variables) - net = slim.repeat( - net, - 9, - block8, - scale=0.20, - trainable_variables=trainable, - reuse=reuse) - net = block8( - net, - activation_fn=None, - trainable_variables=trainable, - reuse=reuse) + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + + net = slim.repeat( + net, + 9, + block8, + scale=0.20, + trainable_variables=trainable, + reuse=reuse) + net = block8( + net, + activation_fn=None, + trainable_variables=trainable, + reuse=reuse) + end_points[name] = net name = "Conv2d_7b_1x1" - trainable = is_trainable(name, trainable_variables) - net = slim.conv2d( - net, 1536, 1, scope=name, trainable=trainable, reuse=reuse) - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + + net = slim.conv2d( + net, 1536, 1, scope=name, trainable=trainable, reuse=reuse) + end_points[name] = net with tf.variable_scope('Logits'): end_points['PrePool'] = net @@ -632,14 +695,19 @@ def inception_resnet_v2(inputs, end_points['PreLogitsFlatten'] = net name = "Bottleneck" - trainable = is_trainable(name, trainable_variables) - net = slim.fully_connected( - net, - bottleneck_layer_size, - activation_fn=None, - scope=name, - reuse=reuse, - trainable=trainable) - end_points[name] = net + trainable = is_trainable(name, trainable_variables, mode=mode) + with slim.arg_scope( + [slim.batch_norm], + is_training=(mode == tf.estimator.ModeKeys.TRAIN), + trainable = trainable): + + net = slim.fully_connected( + net, + bottleneck_layer_size, + activation_fn=None, + scope=name, + reuse=reuse, + trainable=trainable) + end_points[name] = net return net, end_points diff --git a/bob/learn/tensorflow/network/Vgg.py b/bob/learn/tensorflow/network/Vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..f863fd4b92e7f31b3436ac418c32a6362d3f1412 --- /dev/null +++ b/bob/learn/tensorflow/network/Vgg.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : +# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch> + + +""" +VGG16 and VGG19 wrappers +""" + +import tensorflow as tf +from tensorflow.contrib.slim.python.slim.nets import vgg +import tensorflow.contrib.slim as slim + + +def 
vgg_19(inputs, + reuse=None, + mode=tf.estimator.ModeKeys.TRAIN, **kwargs): + """ + Oxford Net VGG 19-Layers version E Example from tf-slim + + https://raw.githubusercontent.com/tensorflow/models/master/research/slim/nets/vgg.py + + **Parameters**: + + inputs: a 4-D tensor of size [batch_size, height, width, 3]. + + reuse: whether or not the network and its variables should be reused. To be + able to reuse 'scope' must be given. + + mode: + Estimator mode keys + """ + + with slim.arg_scope( + [slim.conv2d], + trainable=mode==tf.estimator.ModeKeys.TRAIN): + + return vgg.vgg_19(inputs, spatial_squeeze=False) + + +def vgg_16(inputs, + reuse=None, + mode=tf.estimator.ModeKeys.TRAIN, **kwargs): + """ + Oxford Net VGG 16-Layers version E Example from tf-slim + + https://raw.githubusercontent.com/tensorflow/models/master/research/slim/nets/vgg.py + + **Parameters**: + + inputs: a 4-D tensor of size [batch_size, height, width, 3]. + + reuse: whether or not the network and its variables should be reused. To be + able to reuse 'scope' must be given. + + mode: + Estimator mode keys + """ + + with slim.arg_scope( + [slim.conv2d], + trainable=mode==tf.estimator.ModeKeys.TRAIN): + + return vgg.vgg_16(inputs, spatial_squeeze=False) + diff --git a/bob/learn/tensorflow/network/__init__.py b/bob/learn/tensorflow/network/__init__.py index 9e29be5b3815b7d073f2b87ee5ee0049daff83a2..e6418da06fc1fcb33654c9c5f79f3ee9c0d517de 100644 --- a/bob/learn/tensorflow/network/__init__.py +++ b/bob/learn/tensorflow/network/__init__.py @@ -5,6 +5,7 @@ from .MLP import mlp from .InceptionResnetV2 import inception_resnet_v2, inception_resnet_v2_batch_norm from .InceptionResnetV1 import inception_resnet_v1, inception_resnet_v1_batch_norm from . import SimpleCNN +from .Vgg import vgg_19, vgg_16 # gets sphinx autodoc done right - don't remove it diff --git a/bob/learn/tensorflow/network/utils.py b/bob/learn/tensorflow/network/utils.py index 852db26727592b00762d1d8984190602f9716fc8..df4ef0cb09d44443faed1a17712906b86c1c48ef 100644 --- a/bob/learn/tensorflow/network/utils.py +++ b/bob/learn/tensorflow/network/utils.py @@ -22,7 +22,7 @@ def append_logits(graph, reuse=reuse) -def is_trainable(name, trainable_variables): +def is_trainable(name, trainable_variables, mode=tf.estimator.ModeKeys.TRAIN): """ Check if a variable is trainable or not @@ -37,9 +37,14 @@ def is_trainable(name, trainable_variables): If None, the variable/scope is trained """ + # if mode is not training, so we shutdown + if mode != tf.estimator.ModeKeys.TRAIN: + return False + # If None, we train by default if trainable_variables is None: return True # Here is my choice to shutdown the whole scope return name in trainable_variables + diff --git a/bob/learn/tensorflow/script/style_transfer.py b/bob/learn/tensorflow/script/style_transfer.py new file mode 100644 index 0000000000000000000000000000000000000000..a11148318143305cd76eb09e89ab1fcb371d1821 --- /dev/null +++ b/bob/learn/tensorflow/script/style_transfer.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python +"""Trains networks using Tensorflow estimators. 
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import logging
+import click
+import tensorflow as tf
+from bob.extension.scripts.click_helper import (verbosity_option,
+                                                ConfigCommand, ResourceOption)
+import bob.io.image
+import bob.io.base
+import numpy
+import bob.ip.base
+import bob.ip.color
+import sys
+import os
+from bob.learn.tensorflow.style_transfer import compute_features, compute_gram
+from bob.learn.tensorflow.loss import linear_gram_style_loss, content_loss, denoising_loss
+
+
+logger = logging.getLogger(__name__)
+
+def wise_shape(shape):
+    if len(shape)==2:
+        return (1, shape[0], shape[1], 1)
+    else:
+        return (1, shape[0], shape[1], shape[2])
+
+def normalize4save(img):
+    return (255 * ((img - numpy.min(img)) / (numpy.max(img)-numpy.min(img)))).astype("uint8")
+
+
+@click.command(
+    entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
+@click.argument('content_image_path', required=True)
+@click.argument('output_path', required=True)
+@click.option('--style-image-paths',
+              cls=ResourceOption,
+              required=True,
+              multiple=True,
+              entry_point_group='bob.learn.tensorflow.style_images',
+              help='List of images that encode the style.')
+@click.option('--architecture',
+              '-a',
+              required=True,
+              cls=ResourceOption,
+              entry_point_group='bob.learn.tensorflow.architecture',
+              help='The base architecture.')
+@click.option('--checkpoint-dir',
+              '-c',
+              required=True,
+              cls=ResourceOption,
+              help='Directory (or file) containing the checkpoint of the base architecture.')
+@click.option('--iterations',
+              '-i',
+              type=click.types.INT,
+              help='Number of optimization steps used to generate the image.',
+              default=1000)
+@click.option('--learning_rate',
+              '-r',
+              type=click.types.FLOAT,
+              help='Learning rate.',
+              default=1.)
+@click.option('--content-weight',
+              type=click.types.FLOAT,
+              help='Weight of the content loss.',
+              default=5.)
+@click.option('--style-weight',
+              type=click.types.FLOAT,
+              help='Weight of the style loss.',
+              default=100.)
+@click.option('--denoise-weight',
+              type=click.types.FLOAT,
+              help='Weight of the denoising loss.',
+              default=100.)
+@click.option('--content-end-points',
+              cls=ResourceOption,
+              multiple=True,
+              entry_point_group='bob.learn.tensorflow.end_points',
+              help='List of end points used to encode the content')
+@click.option('--style-end-points',
+              cls=ResourceOption,
+              multiple=True,
+              entry_point_group='bob.learn.tensorflow.end_points',
+              help='List of end points used to encode the style')
+@click.option('--scopes',
+              cls=ResourceOption,
+              entry_point_group='bob.learn.tensorflow.scopes',
+              help='Dictionary mapping the checkpoint scopes to the model scopes',
+              required=True)
+@click.option('--pure-noise',
+              is_flag=True,
+              help="If set, saves the raw generated (noise) image. "
+              "If not set, the output will be RGB = stylizedYUV.Y, originalYUV.U, originalYUV.V"
+              )
+@click.option('--preprocess-fn',
+              '-pr',
+              cls=ResourceOption,
+              entry_point_group='bob.learn.tensorflow.preprocess_fn',
+              help='Preprocess function. Pointer to a function that preprocesses the INPUT signal')
+@click.option('--un-preprocess-fn',
+              '-un',
+              cls=ResourceOption,
+              entry_point_group='bob.learn.tensorflow.preprocess_fn',
+              help='Un-preprocess function. 
Pointer to a function that preprocess the OUTPUT signal') +@verbosity_option(cls=ResourceOption) +def style_transfer(content_image_path, output_path, style_image_paths, + architecture, checkpoint_dir, + iterations, learning_rate, + content_weight, style_weight, denoise_weight, content_end_points, + style_end_points, scopes, pure_noise, preprocess_fn, + un_preprocess_fn, **kwargs): + """ + Trains neural style transfer using the approach presented in: + + Gatys, Leon A., Alexander S. Ecker, and Matthias Bethge. "A neural algorithm of artistic style." arXiv preprint arXiv:1508.06576 (2015). + + \b + + If you want run a style transfer using InceptionV2 as basis folo + + Below follow a CONFIG template + + CONFIG.PY + ``` + + from bob.extension import rc + + from bob.learn.tensorflow.network import inception_resnet_v2_batch_norm + architecture = inception_resnet_v2_batch_norm + + checkpoint_dir = rc["bob.bio.face_ongoing.idiap_casia_inception_v2_centerloss_rgb"] + + style_end_points = ["Conv2d_1a_3x3", "Conv2d_2b_3x3", "Conv2d_3b_1x1", "Conv2d_4a_3x3"] + + content_end_points = ["Bottleneck", "PreLogitsFlatten"] + + scopes = {"InceptionResnetV2/":"InceptionResnetV2/"} + + ``` + \b + + Then run:: + + $ bob tf style <content-image> <output-image> --style-image-paths <style-image> CONFIG.py + + + You can also provide a list of images to encode the style using the config file as in the example below. + + CONFIG.PY + ``` + + from bob.extension import rc + + from bob.learn.tensorflow.network import inception_resnet_v2_batch_norm + architecture = inception_resnet_v2_batch_norm + + checkpoint_dir = rc["bob.bio.face_ongoing.idiap_casia_inception_v2_centerloss_rgb"] + + style_end_points = ["Conv2d_1a_3x3", "Conv2d_2b_3x3", "Conv2d_3b_1x1", "Conv2d_4a_3x3"] + + content_end_points = ["Bottleneck", "PreLogitsFlatten"] + + scopes = {"InceptionResnetV2/":"InceptionResnetV2/"} + + style_image_paths = ["STYLE_1.png", + "STYLE_2.png"] + + ``` + + Then run:: + + $ bob tf style <content-image> <output-image> CONFIG.py + + \b \b + + """ + + # Reading and converting to the tensorflow format + content_image = bob.io.image.to_matplotlib(bob.io.base.load(content_image_path)).astype("float32") + style_images = [] + for path in style_image_paths: + image = bob.io.image.to_matplotlib(bob.io.base.load(path)).astype("float32") + style_images.append(numpy.reshape(image, wise_shape(image.shape))) + + # Reshaping to NxWxHxC + content_image = numpy.reshape(content_image, wise_shape(content_image.shape)) + + # Base content features + logger.info("Computing content features") + content_features = compute_features(content_image, architecture, checkpoint_dir, + content_end_points, preprocess_fn) + + # Base style features + logger.info("Computing style features") + style_grams = [] + for image in style_images: + style_features = compute_features(image, architecture, checkpoint_dir, + style_end_points, preprocess_fn) + style_grams.append(compute_gram(style_features)) + + # Organizing the trainer + logger.info("Training.....") + with tf.Graph().as_default(): + tf.set_random_seed(0) + + # Random noise + noise = tf.Variable(tf.random_normal(shape=content_image.shape), + trainable=True) * 0.256 + _, end_points = architecture(noise, + mode=tf.estimator.ModeKeys.PREDICT, + trainable_variables=[]) + + # Computing content loss + content_noises = [] + for c in content_end_points: + content_noises.append(end_points[c]) + c_loss = content_loss(content_noises, content_features) + + # Computing style_loss + style_gram_noises = [] + s_loss = 0 + for 
grams_per_image in style_grams: + + for c in style_end_points: + layer = end_points[c] + _, height, width, number = map(lambda i: i.value, layer.get_shape()) + size = height * width * number + features = tf.reshape(layer, (-1, number)) + style_gram_noises.append(tf.matmul(tf.transpose(features), features) / size) + s_loss += linear_gram_style_loss(style_gram_noises, grams_per_image) + + # Variation denoise + d_loss = denoising_loss(noise) + + #Total loss + total_loss = content_weight*c_loss + style_weight*s_loss + denoise_weight*d_loss + + solver = tf.train.AdamOptimizer(learning_rate).minimize(total_loss) + + tf.contrib.framework.init_from_checkpoint(tf.train.latest_checkpoint(checkpoint_dir) if os.path.isdir(checkpoint_dir) else checkpoint_dir, scopes) + # Training + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + + for i in range(iterations): + _, loss = sess.run([solver, total_loss]) + logger.info("Iteration {0}, loss {1}".format(i, loss)) + sys.stdout.flush() + + # Saving generated image + raw_style_image = sess.run(noise)[0, :, :,:] + # Unpreprocessing the signal + if un_preprocess_fn is not None: + raw_style_image = un_preprocess_fn(raw_style_image) + + raw_style_image = bob.io.image.to_bob(raw_style_image) + normalized_style_image = normalize4save(raw_style_image) + + if pure_noise: + if normalized_style_image.shape[0] == 1: + bob.io.base.save(normalized_style_image[0, :, :], output_path) + else: + bob.io.base.save(normalized_style_image, output_path) + else: + # Original output + if normalized_style_image.shape[0] == 1: + normalized_style_image_yuv = bob.ip.color.rgb_to_yuv(bob.ip.color.gray_to_rgb(normalized_style_image[0,:,:])) + # Loading the content image and clipping from 0-255 in case is in another scale + scaled_content_image = normalize4save(bob.io.base.load(content_image_path).astype("float32")).astype("float64") + content_image_yuv = bob.ip.color.rgb_to_yuv(bob.ip.color.gray_to_rgb(scaled_content_image)) + else: + normalized_style_image_yuv = bob.ip.color.rgb_to_yuv(bob.ip.color.gray_to_rgb(bob.ip.color.rgb_to_gray(normalized_style_image))) + content_image_yuv = bob.ip.color.rgb_to_yuv(bob.io.base.load(content_image_path)) + + output_image = numpy.zeros(shape=content_image_yuv.shape, dtype="uint8") + output_image[0,:,:] = normalized_style_image_yuv[0,:,:] + output_image[1,:,:] = content_image_yuv[1,:,:] + output_image[2,:,:] = content_image_yuv[2,:,:] + + output_image = bob.ip.color.yuv_to_rgb(output_image) + bob.io.base.save(output_image, output_path) + diff --git a/bob/learn/tensorflow/style_transfer/__init__.py b/bob/learn/tensorflow/style_transfer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b6720548f70e2c68b58bf55c297b67b1dd61a45 --- /dev/null +++ b/bob/learn/tensorflow/style_transfer/__init__.py @@ -0,0 +1,22 @@ +from .neural_transfer import compute_features, compute_gram + +# gets sphinx autodoc done right - don't remove it +def __appropriate__(*args): + """Says object was actually declared here, an not on the import module. 
+
+    Parameters:
+
+      *args: An iterable of objects to modify
+
+    Resolves `Sphinx referencing issues
+    <https://github.com/sphinx-doc/sphinx/issues/3048>`
+    """
+
+    for obj in args:
+        obj.__module__ = __name__
+
+
+__appropriate__(
+)
+
+__all__ = [_ for _ in dir() if not _.startswith('_')]
diff --git a/bob/learn/tensorflow/style_transfer/neural_transfer.py b/bob/learn/tensorflow/style_transfer/neural_transfer.py
new file mode 100644
index 0000000000000000000000000000000000000000..994bf78e54df84ec09803df013a02d300c6c00db
--- /dev/null
+++ b/bob/learn/tensorflow/style_transfer/neural_transfer.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python
+# vim: set fileencoding=utf-8 :
+# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
+
+
+import tensorflow as tf
+import numpy
+import os
+
+def compute_features(input_image, architecture, checkpoint_dir, target_end_points, preprocess_fn=None):
+    """
+    For a given set of end points, convolve the input image up to these points
+
+    Parameters
+    ----------
+
+    input_image: numpy.array
+        Input image in the format NxWxHxC
+
+    architecture:
+        Pointer to the architecture function
+
+    checkpoint_dir: str
+        DCNN checkpoint directory (or checkpoint file)
+
+    target_end_points: list
+        List of names of the end points to return
+
+    preprocess_fn:
+        Pointer to a preprocess function
+
+    """
+
+    input_pl = tf.placeholder('float32', shape=(1, input_image.shape[1],
+                                                input_image.shape[2],
+                                                input_image.shape[3]))
+
+    if preprocess_fn is None:
+        _, end_points = architecture(input_pl, mode=tf.estimator.ModeKeys.PREDICT, trainable_variables=None)
+    else:
+        _, end_points = architecture(tf.stack([preprocess_fn(i) for i in tf.unstack(input_pl)]), mode=tf.estimator.ModeKeys.PREDICT, trainable_variables=None)
+    with tf.Session() as sess:
+        # Restoring the checkpoint for the given architecture
+        sess.run(tf.global_variables_initializer())
+        saver = tf.train.Saver()
+
+        if os.path.isdir(checkpoint_dir):
+            saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
+        else:
+            saver.restore(sess, checkpoint_dir)
+
+        #content_feature = sess.run(end_points[CONTENT_END_POINTS], feed_dict={input_image: content_image})
+        features = []
+        for ep in target_end_points:
+            feature = sess.run(end_points[ep], feed_dict={input_pl: input_image})
+            features.append(feature)
+
+    # Killing the graph
+    tf.reset_default_graph()
+    return features
+
+
+def compute_gram(features):
+    """
+    Given a list of features (as numpy.arrays), compute the Gram matrix of each one
+    over the channel dimension, as in:
+
+    Gatys, Leon A., Alexander S. Ecker, and Matthias Bethge. "A neural algorithm of artistic style." arXiv preprint arXiv:1508.06576 (2015). 
+ + Parameters + ---------- + + features: numpy.array + Convolved features in the format NxWxHxC + + """ + + grams = [] + for f in features: + f = numpy.reshape(f, (-1, f.shape[3])) + grams.append(numpy.matmul(f.T, f) / f.size) + + return grams + diff --git a/bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p01_i0_0_GRAY.png b/bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p01_i0_0_GRAY.png new file mode 100644 index 0000000000000000000000000000000000000000..e7de9b7d4b792351e2724ada32bd88d9dc5d3ff0 Binary files /dev/null and b/bob/learn/tensorflow/test/data/dummy_image_database/m301_01_p01_i0_0_GRAY.png differ diff --git a/bob/learn/tensorflow/test/test_architectures.py b/bob/learn/tensorflow/test/test_architectures.py index 85cd8310c2b89455d458ea4e77e6af9e3befae1a..bb17997d360e869967b83b5b61eff2f44a257e71 100644 --- a/bob/learn/tensorflow/test/test_architectures.py +++ b/bob/learn/tensorflow/test/test_architectures.py @@ -4,7 +4,8 @@ import tensorflow as tf from bob.learn.tensorflow.network import inception_resnet_v2, inception_resnet_v2_batch_norm,\ - inception_resnet_v1, inception_resnet_v1_batch_norm + inception_resnet_v1, inception_resnet_v1_batch_norm,\ + vgg_19, vgg_16 def test_inceptionv2(): @@ -43,3 +44,41 @@ def test_inceptionv1(): tf.reset_default_graph() assert len(tf.global_variables()) == 0 + + +def test_vgg(): + # Testing VGG19 Training mode + inputs = tf.placeholder(tf.float32, shape=(1, 224, 224, 3)) + graph, _ = vgg_19(inputs) + assert len(tf.trainable_variables()) == 38 + + tf.reset_default_graph() + assert len(tf.global_variables()) == 0 + + + # Testing VGG19 predicting mode + inputs = tf.placeholder(tf.float32, shape=(1, 224, 224, 3)) + graph, _ = vgg_19(inputs, mode=tf.estimator.ModeKeys.PREDICT) + assert len(tf.trainable_variables()) == 0 + + tf.reset_default_graph() + assert len(tf.global_variables()) == 0 + + + # Testing VGG 16 training mode + inputs = tf.placeholder(tf.float32, shape=(1, 224, 224, 3)) + graph, _ = vgg_16(inputs) + assert len(tf.trainable_variables()) == 32 + + tf.reset_default_graph() + assert len(tf.global_variables()) == 0 + + + # Testing VGG 16 predicting mode + inputs = tf.placeholder(tf.float32, shape=(1, 224, 224, 3)) + graph, _ = vgg_16(inputs, mode=tf.estimator.ModeKeys.PREDICT) + assert len(tf.trainable_variables()) == 0 + + tf.reset_default_graph() + assert len(tf.global_variables()) == 0 + diff --git a/bob/learn/tensorflow/test/test_image_dataset.py b/bob/learn/tensorflow/test/test_image_dataset.py index 4d65df1013853968ad3b285b044e529006c4d120..9efb8f5791b011e26620fe5f05fa4703e284cdf3 100644 --- a/bob/learn/tensorflow/test/test_image_dataset.py +++ b/bob/learn/tensorflow/test/test_image_dataset.py @@ -42,8 +42,6 @@ def test_logitstrainer_images(): run_logitstrainer_images(trainer) finally: try: - os.unlink(tfrecord_train) - os.unlink(tfrecord_validation) shutil.rmtree(model_dir, ignore_errors=True) except Exception: pass diff --git a/bob/learn/tensorflow/test/test_style_transfer.py b/bob/learn/tensorflow/test/test_style_transfer.py new file mode 100644 index 0000000000000000000000000000000000000000..44bd3c99c1cff2c0c54d528b7ce32e3c7fe0735b --- /dev/null +++ b/bob/learn/tensorflow/test/test_style_transfer.py @@ -0,0 +1,79 @@ +from __future__ import print_function +import os +import shutil +from glob import glob +from tempfile import mkdtemp +from click.testing import CliRunner +from bob.io.base.test_utils import datafile +import pkg_resources + +import tensorflow as tf + +from bob.learn.tensorflow.utils import 
load_mnist, create_mnist_tfrecord +from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator +from bob.learn.tensorflow.loss import mean_cross_entropy_loss +from bob.learn.tensorflow.utils import reproducible +from .test_estimator_onegraph import run_logitstrainer_mnist + +from bob.learn.tensorflow.estimators import Logits +from bob.learn.tensorflow.network import dummy +from bob.learn.tensorflow.script.style_transfer import style_transfer + +dummy_config = datafile('style_transfer.py', __name__) +CONFIG = ''' +from bob.learn.tensorflow.network import dummy +architecture = dummy +import pkg_resources + +checkpoint_dir = "./temp/" + +style_end_points = ["conv1"] +content_end_points = ["fc1"] + +scopes = {"Dummy/":"Dummy/"} + +''' + + +#tfrecord_train = "./train_mnist.tfrecord" +model_dir = "./temp" +output_style_image = 'output_style.png' + +learning_rate = 0.1 +data_shape = (28, 28, 1) # size of atnt images +data_type = tf.float32 +batch_size = 32 +epochs = 1 +steps = 100 + + +def test_style_transfer(): + with open(dummy_config, 'w') as f: + f.write(CONFIG) + + # Trainer logits + + # CREATING FAKE MODEL USING MNIST + _, run_config,_,_,_ = reproducible.set_seed() + trainer = Logits( + model_dir=model_dir, + architecture=dummy, + optimizer=tf.train.GradientDescentOptimizer(learning_rate), + n_classes=10, + loss_op=mean_cross_entropy_loss, + config=run_config) + run_logitstrainer_mnist(trainer) + + # Style transfer using this fake model + runner = CliRunner() + result = runner.invoke(style_transfer, + args=[pkg_resources.resource_filename( __name__, 'data/dummy_image_database/m301_01_p01_i0_0_GRAY.png'), + output_style_image, dummy_config]) + + try: + os.unlink(dummy_config) + shutil.rmtree(model_dir, ignore_errors=True) + except Exception: + pass + + diff --git a/conda/meta.yaml b/conda/meta.yaml index 27cb64b0bfe95ed91f433c0c9e0d89b74080b684..7ae2cb359f98781509e17556cf2128c8bd27739e 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -37,6 +37,7 @@ requirements: - bob.db.mnist - bob.db.atnt - bob.bio.base + - bob.ip.color run: - python - setuptools diff --git a/setup.py b/setup.py index 8ec07c7dfa05fd96d5380ea64d31693ecfe6997c..8a0a51a7ac58ba078cd3894021d099a733bda25f 100644 --- a/setup.py +++ b/setup.py @@ -56,6 +56,7 @@ setup( 'predict_bio = bob.learn.tensorflow.script.predict_bio:predict_bio', 'train = bob.learn.tensorflow.script.train:train', 'train_and_evaluate = bob.learn.tensorflow.script.train_and_evaluate:train_and_evaluate', + 'style_transfer = bob.learn.tensorflow.script.style_transfer:style_transfer' ], },
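
For reference, the two helpers added in bob/learn/tensorflow/style_transfer/neural_transfer.py can also be driven directly from Python, outside of the `bob tf style` command. A minimal sketch, assuming a checkpoint trained for the chosen architecture is available (the checkpoint path below is a placeholder; the style image and end-point names are the ones used in the example configs above):

```
import bob.io.base
import bob.io.image
from bob.learn.tensorflow.network import inception_resnet_v2_batch_norm
from bob.learn.tensorflow.style_transfer import compute_features, compute_gram

checkpoint_dir = "/path/to/inception-v2-checkpoint"  # placeholder

# Load a style image and reshape it to NxWxHxC, as the CLI does internally
style = bob.io.image.to_matplotlib(
    bob.io.base.load("vincent_van_gogh.jpg")).astype("float32")
style = style.reshape((1,) + style.shape)

# Convolve the image up to the selected end points...
features = compute_features(style, inception_resnet_v2_batch_norm,
                            checkpoint_dir,
                            ["Conv2d_1a_3x3", "Conv2d_2b_3x3"])

# ...and reduce each feature map to its CxC Gram matrix; these are the
# targets that linear_gram_style_loss compares against the generated image.
grams = compute_gram(features)
print([g.shape for g in grams])
```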