#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>


import tensorflow as tf
import numpy
import os
from bob.learn.tensorflow.loss import linear_gram_style_loss, content_loss, denoising_loss
import bob.io.image
import bob.ip.color

import logging
logger = logging.getLogger(__name__)


def compute_features(input_image, architecture, checkpoint_dir, target_end_points, preprocess_fn=None):
    """
    Given a set of end points, forward the input image through the network
    and return the features at those points

    Parameters
    ----------

    input_image: :any:`numpy.array`
        Input image in the format NxWxHxC

    architecture:
        Pointer to the architecture function

    checkpoint_dir: str
        DCNN checkpoint directory

    target_end_points: list
       List with the names of the end points (tensors) to compute

    preprocess_fn:
       Pointer to a preprocess function

    """

    input_pl = tf.placeholder('float32', shape=(1, input_image.shape[1],
                                                   input_image.shape[2],
                                                   input_image.shape[3]))

    if preprocess_fn is None:
        _, end_points = architecture(input_pl, mode=tf.estimator.ModeKeys.PREDICT, trainable_variables=None)
    else:
        _, end_points = architecture(tf.stack([preprocess_fn(i) for i in tf.unstack(input_pl)]), mode=tf.estimator.ModeKeys.PREDICT, trainable_variables=None)
    with tf.Session() as sess:
        # Restoring the checkpoint for the given architecture
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        if os.path.isdir(checkpoint_dir):
            saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
        else:
            saver.restore(sess, checkpoint_dir)

        features = sess.run([end_points[ep] for ep in target_end_points], feed_dict={input_pl: input_image})

    # Killing the graph
    tf.reset_default_graph()
    return features
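
# A minimal usage sketch for ``compute_features``; ``my_architecture`` and the
# checkpoint path are placeholders for whatever slim-style function (returning
# ``(logits, end_points)``) and model are actually used:
#
#   image = numpy.random.rand(1, 224, 224, 3).astype("float32")  # NxWxHxC
#   features = compute_features(image, my_architecture, "/path/to/checkpoint",
#                               ["conv1_1", "conv2_1"])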


def compute_gram(features):
    """
    Given a list of features (as numpy.arrays) comput the gram matrices of each
    pinning the channel as in:

    Gatys, Leon A., Alexander S. Ecker, and Matthias Bethge. "A neural algorithm of artistic style." arXiv preprint arXiv:1508.06576 (2015).

    Parameters
    ----------

    features: :any:`numpy.array`
      Convolved features in the format NxWxHxC

    """

    grams = []
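    # Flatten each NxWxHxC feature map into an (N*W*H) x C matrix F and take
    # G = F^T F / F.size, one CxC channel-correlation matrix per feature map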
    for f in features:
        f = numpy.reshape(f, (-1, f.shape[3]))
        grams.append(numpy.matmul(f.T, f) / f.size)

    return grams
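
# A minimal sketch of what ``compute_gram`` returns (shapes are illustrative):
#
#   f = numpy.random.rand(1, 32, 32, 64).astype("float32")  # NxWxHxC features
#   gram, = compute_gram([f])
#   assert gram.shape == (64, 64)  # one CxC Gram matrix per feature map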


def do_style_transfer(content_image, style_images,
                      architecture, checkpoint_dir, scopes,
                      content_end_points, style_end_points,
                      preprocess_fn=None, un_preprocess_fn=None, pure_noise=False,
                      iterations=1000, learning_rate=0.1,
                      content_weight=5., style_weight=500., denoise_weight=500., start_from="noise"):

    """
    Trains neural style transfer using the approach presented in:

    Gatys, Leon A., Alexander S. Ecker, and Matthias Bethge. "A neural algorithm of artistic style." arXiv preprint arXiv:1508.06576 (2015).

    Parameters
    ----------

    content_image: :any:`numpy.array`
       Content image in the Bob format (C x W x H)

    style_images: :any:`list`
       List of numpy.array (Bob format (C x W x H)) that encode the style

    architecture:
       Pointer to the function that builds the base architecture

    checkpoint_dir:
       CNN checkpoint path

    scopes:
       Dictionary mapping the variable scopes in the checkpoint to the scopes
       in the graph (passed to ``tf.contrib.framework.init_from_checkpoint``)

    content_end_points:
       List of end_points (from the architecture) used to encode the content

    style_end_points:
       List of end_points (from the architecture) used to encode the style

    preprocess_fn:
       Preprocessing function. Pointer to a function that preprocesses the INPUT signal

    un_preprocess_fn:
       Un-preprocessing function. Pointer to a function that un-preprocesses the OUTPUT signal

    pure_noise:
       If set, the raw generated image is returned.
       If not set, the output will be RGB = stylizedYUV.Y, originalYUV.U, originalYUV.V

    iterations:
       Number of iterations to generate the image

    learning_rate:
       Adam learning rate

    content_weight:
       Weight of the content loss

    style_weight:
       Weight of the style loss

    denoise_weight:
       Weight of the denoising loss

    start_from:
       Starting point of the optimization: "noise" (random noise), "content"
       (the content image) or "style" (the first style image)
    """

    def wise_shape(shape):
        if len(shape) == 2:
            return (1, shape[0], shape[1], 1)
        else:
            return (1, shape[0], shape[1], shape[2])

    def normalize4save(img):
        return (255 * ((img - numpy.min(img)) / (numpy.max(img)-numpy.min(img)))).astype("uint8")

    # Reshaping to NxWxHxC and converting to the tensorflow format
    # content
    original_image = content_image
    content_image = bob.io.image.to_matplotlib(content_image).astype("float32")
    content_image = numpy.reshape(content_image, wise_shape(content_image.shape))

    # and style
    for i in range(len(style_images)):
        image = bob.io.image.to_matplotlib(style_images[i])
        image = numpy.reshape(image, wise_shape(image.shape))
        style_images[i] = image

    # Base content features
    logger.info("Computing content features")
    content_features = compute_features(content_image, architecture, checkpoint_dir,
                                        content_end_points, preprocess_fn)

    # Base style features
    logger.info("Computing style features")
    style_grams = []
    for image in style_images:
        style_features = compute_features(image, architecture, checkpoint_dir,
                                          style_end_points, preprocess_fn)
        style_grams.append(compute_gram(style_features))

    # Organizing the trainer
    logger.info("Training.....")
    with tf.Graph().as_default():
        tf.set_random_seed(0)

        # Starting image: random noise, the content image or a style image
        if start_from == "noise":
            starting_image = tf.random_normal(shape=content_image.shape) * 0.256
        elif start_from == "content":
            starting_image = preprocess_fn(content_image)
        elif start_from == "style":
            starting_image = preprocess_fn(style_images[0])
        else:
            raise ValueError(f"Unknown starting image: {start_from}")

        noise = tf.Variable(starting_image, dtype="float32", trainable=True)
        _, end_points = architecture(noise,
                                      mode=tf.estimator.ModeKeys.PREDICT,
                                      trainable_variables=[])

        # Computing content loss
        content_noises = []
        for c in content_end_points:
            content_noises.append(end_points[c])
        c_loss = content_loss(content_noises, content_features)

        # Computing style_loss
        # The Gram matrices of the generated image depend only on the graph,
        # so they are built once and compared against each style image
        style_gram_noises = []
        for c in style_end_points:
            layer = end_points[c]
            _, height, width, number = map(lambda i: i.value, layer.get_shape())
            size = height * width * number
            features = tf.reshape(layer, (-1, number))
            style_gram_noises.append(tf.matmul(tf.transpose(features), features) / size)

        s_loss = 0
        for grams_per_image in style_grams:
            s_loss += linear_gram_style_loss(style_gram_noises, grams_per_image)

        # Variation denoise
        d_loss = denoising_loss(noise)

        # Total loss: weighted combination of the content, style and denoising terms
        total_loss = content_weight*c_loss + style_weight*s_loss + denoise_weight*d_loss

        solver = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)

        checkpoint = tf.train.latest_checkpoint(checkpoint_dir) if os.path.isdir(checkpoint_dir) else checkpoint_dir
        tf.contrib.framework.init_from_checkpoint(checkpoint, scopes)
        # Training
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            for i in range(iterations):
                _, loss = sess.run([solver, total_loss])
                logger.info("Iteration {0}, loss {1}".format(i, loss))

            # Saving generated image
            raw_style_image = sess.run(noise)[0, :, :,:]
            # Unpreprocessing the signal
            if un_preprocess_fn is not None:
                raw_style_image = un_preprocess_fn(raw_style_image)

            raw_style_image = bob.io.image.to_bob(raw_style_image)
            normalized_style_image = normalize4save(raw_style_image)

            if pure_noise:
                if normalized_style_image.shape[0] == 1:
                    return normalized_style_image[0, :, :]
                else:
                    return normalized_style_image
            else:
                # Keeping the stylized luminance (Y) and the original chrominance (U, V)
                if normalized_style_image.shape[0] == 1:
                    normalized_style_image_yuv = bob.ip.color.rgb_to_yuv(bob.ip.color.gray_to_rgb(normalized_style_image[0,:,:]))
                    # Using the original content image, cast to float64 and
                    # assumed to be in the 0-255 range
                    scaled_content_image = original_image.astype("float64")
                    content_image_yuv = bob.ip.color.rgb_to_yuv(bob.ip.color.gray_to_rgb(scaled_content_image))
                else:
                    normalized_style_image_yuv = bob.ip.color.rgb_to_yuv(bob.ip.color.gray_to_rgb(bob.ip.color.rgb_to_gray(normalized_style_image)))
                    content_image_yuv = bob.ip.color.rgb_to_yuv(original_image)

                output_image = numpy.zeros(shape=content_image_yuv.shape, dtype="uint8")
                output_image[0,:,:] = normalized_style_image_yuv[0,:,:]
                output_image[1,:,:] = content_image_yuv[1,:,:]
                output_image[2,:,:] = content_image_yuv[2,:,:]

                output_image = bob.ip.color.yuv_to_rgb(output_image)
                return output_image
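
# A minimal usage sketch for ``do_style_transfer``; ``my_architecture``, the
# checkpoint path, the scope mapping and the end point names are placeholders
# that depend on the network actually used:
#
#   content = bob.io.base.load("content.png")       # Bob format, C x W x H
#   styles = [bob.io.base.load("style.png")]
#   output = do_style_transfer(content, styles, my_architecture,
#                              "/path/to/checkpoint",
#                              scopes={"my_scope/": "my_scope/"},
#                              content_end_points=["conv4_2"],
#                              style_end_points=["conv1_1", "conv2_1"])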