#!/usr/bin/env python
"""Trains networks using Tensorflow estimators.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import click
import tensorflow as tf
from bob.extension.scripts.click_helper import (verbosity_option,
                                                ConfigCommand, ResourceOption)
import bob.io.image
import bob.io.base
import numpy
import bob.ip.color
import sys

from bob.learn.tensorflow.style_transfer import compute_features, compute_gram
from bob.learn.tensorflow.loss import linear_gram_style_loss, content_loss, denoising_loss


logger = logging.getLogger(__name__)

def normalize4save(img):
    """Min-max normalizes an image to the [0, 255] range and casts it to uint8 for saving."""
    return (255 * ((img - numpy.min(img)) / (numpy.max(img) - numpy.min(img)))).astype("uint8")



@click.command(
    entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
@click.argument('content_image_path', required=True)
@click.argument('output_path', required=True)
@click.option('--style-image-paths',
              cls=ResourceOption,
              required=True,
              multiple=True,
              entry_point_group='bob.learn.tensorflow.style_images',
              help='List of images that encode the style.')
@click.option('--architecture',
              '-a',
              required=True,
              cls=ResourceOption,
              entry_point_group='bob.learn.tensorflow.architecture',
              help='The base architecture.')
@click.option('--checkpoint-dir',
              '-c',
              required=True,
              cls=ResourceOption,
              help='The directory containing the checkpoint of the base architecture.')
@click.option('--iterations',
              '-i',
              type=click.types.INT,
              help='Number of steps for which to train the model.',
              default=1000)
@click.option('--learning-rate',
              '-r',
              type=click.types.FLOAT,
              help='Learning rate.',
              default=1.)
@click.option('--content-weight',
              type=click.types.FLOAT,
              help='Weight of the content loss.',
              default=1.)
@click.option('--style-weight',
              type=click.types.FLOAT,
              help='Weight of the style loss.',
              default=1000.)
@click.option('--denoise-weight',
              type=click.types.FLOAT,
              help='Weight of the denoising loss.',
              default=100.)
@click.option('--content-end-points',
              cls=ResourceOption,
              multiple=True,
              entry_point_group='bob.learn.tensorflow.end_points',
              help='List of end_points used to encode the content.')
@click.option('--style-end-points',
              cls=ResourceOption,
              multiple=True,
              entry_point_group='bob.learn.tensorflow.end_points',
              help='List of end_points used to encode the style.')
@click.option('--scopes',
              cls=ResourceOption,
              entry_point_group='bob.learn.tensorflow.scopes',
              help='Dictionary mapping the variable scopes in the checkpoint to the scopes in the current graph.',
              required=True)
@click.option('--pure-noise',
               is_flag=True,
               help="If set, saves the raw generated image. "
                    "If not set, the output will be RGB = stylizedYUV.Y, originalYUV.U, originalYUV.V."
              )
@verbosity_option(cls=ResourceOption)
def style_transfer(content_image_path, output_path, style_image_paths,
                   architecture, checkpoint_dir,
                   iterations, learning_rate,
                   content_weight, style_weight, denoise_weight, content_end_points,
                   style_end_points, scopes, pure_noise, **kwargs):
    """
    Trains a neural style transfer using the approach presented in:

    Gatys, Leon A., Alexander S. Ecker, and Matthias Bethge. "A neural algorithm of artistic style." arXiv preprint arXiv:1508.06576 (2015).

    \b

    If you want to run a style transfer using InceptionV2 as the base architecture, follow the CONFIG template below:

    CONFIG.PY
    ```

       from bob.extension import rc

       from bob.learn.tensorflow.network import inception_resnet_v2_batch_norm
       architecture = inception_resnet_v2_batch_norm

       checkpoint_dir = rc["bob.bio.face_ongoing.idiap_casia_inception_v2_centerloss_rgb"]

       style_end_points = ["Conv2d_1a_3x3", "Conv2d_2b_3x3", "Conv2d_3b_1x1", "Conv2d_4a_3x3"]

       content_end_points = ["Bottleneck", "PreLogitsFlatten"]

       scopes = {"InceptionResnetV2/":"InceptionResnetV2/"}

    ```
    \b

    Then run::

       $ bob tf style <content-image> <output-image> --style-image-paths <style-image> CONFIG.py


    You can also provide a list of images to encode the style via the config file, as in the example below.

    CONFIG.PY
    ```

       from bob.extension import rc

       from bob.learn.tensorflow.network import inception_resnet_v2_batch_norm
       architecture = inception_resnet_v2_batch_norm

       checkpoint_dir = rc["bob.bio.face_ongoing.idiap_casia_inception_v2_centerloss_rgb"]

       style_end_points = ["Conv2d_1a_3x3", "Conv2d_2b_3x3", "Conv2d_3b_1x1", "Conv2d_4a_3x3"]

       content_end_points = ["Bottleneck", "PreLogitsFlatten"]

       scopes = {"InceptionResnetV2/":"InceptionResnetV2/"}

       style_image_paths = ["STYLE_1.png",
                            "STYLE_2.png"]

    ```
 
    Then run::

       $ bob tf style <content-image> <output-image> CONFIG.py

    \b \b

    """

    # Read the images and convert them from Bob's CxHxW layout to the HxWxC layout expected by TensorFlow
    content_image = bob.io.image.to_matplotlib(bob.io.base.load(content_image_path))
    style_images = []
    for path in style_image_paths:
        image = bob.io.image.to_matplotlib(bob.io.base.load(path))
        style_images.append(numpy.reshape(image, (1, image.shape[0],
                                                     image.shape[1],
                                                     image.shape[2])))

    # Reshaping to NxHxWxC
    content_image = numpy.reshape(content_image, (1, content_image.shape[0],
                                                     content_image.shape[1],
                                                     content_image.shape[2]))

    # Base content features
    content_features = compute_features(content_image, architecture, checkpoint_dir, content_end_points)
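    # The content is encoded by the raw activations at the chosen end-points;
    # the optimization below pushes the generated image to reproduce them.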

    # Base style features, one set of Gram matrices per style image
    style_grams = []
    for image in style_images:
        style_features = compute_features(image, architecture, checkpoint_dir, style_end_points)
        style_grams.append(compute_gram(style_features))
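
    # NOTE: compute_gram is expected to produce, for each style end-point, the
    # Gram matrix G = F^T.F / (height*width*channels) of the activations F,
    # the same expression used for the generated image below. Matching these
    # channel correlations is what transfers the style (Gatys et al., 2015).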

    # Building the graph that optimizes the generated image
    with tf.Graph().as_default():
        tf.set_random_seed(0)

        # Random noise
        noise = tf.Variable(tf.random_normal(shape=content_image.shape),
                            trainable=True)

        _, end_points = architecture(noise,
                                      mode=tf.estimator.ModeKeys.PREDICT,
                                      trainable_variables=[])
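        # Only the generated image is trainable: the network weights are kept
        # frozen (trainable_variables=[]) and restored from the checkpoint below.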

        # Computing content loss
        content_noises = []
        for c in content_end_points:
            content_noises.append(end_points[c])
        c_loss = content_loss(content_noises, content_features)

        # Computing the style loss
        s_loss = 0
        for grams_per_image in style_grams:
            # Gram matrices of the generated image, recomputed per style image
            # (resetting the list here keeps it aligned with grams_per_image)
            style_gram_noises = []
            for c in style_end_points:
                layer = end_points[c]
                _, height, width, number = map(lambda i: i.value, layer.get_shape())
                size = height * width * number
                features = tf.reshape(layer, (-1, number))
                style_gram_noises.append(tf.matmul(tf.transpose(features), features) / size)
            s_loss += linear_gram_style_loss(style_gram_noises, grams_per_image)

        # Total variation denoising to smooth the generated image
        d_loss = denoising_loss(noise)

        # Total loss: weighted combination of the content, style and denoising terms
        total_loss = content_weight*c_loss + style_weight*s_loss + denoise_weight*d_loss

        solver = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)

        tf.contrib.framework.init_from_checkpoint(tf.train.latest_checkpoint(checkpoint_dir),
                                                  scopes)
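        # `scopes` is the assignment_map of init_from_checkpoint: it maps
        # variable scopes in the checkpoint to the corresponding scopes in the
        # current graph.
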
        # Training
        with tf.Session() as sess: 
            sess.run(tf.global_variables_initializer())

            for i in range(iterations):
                _, loss = sess.run([solver, total_loss])
                print("Iteration {0}, loss {1}".format(i, loss))
                sys.stdout.flush()

            # Saving the generated image
            raw_style_image = sess.run(noise)[0, :, :, :]
            raw_style_image = bob.io.image.to_bob(raw_style_image)
            normalized_style_image = normalize4save(raw_style_image)

            if pure_noise:
                bob.io.base.save(normalized_style_image, output_path)
            else:
                # Keep only the luminance (Y) channel of the stylized image and
                # take the chrominance (U, V) channels from the original content image
                normalized_style_image_yuv = bob.ip.color.rgb_to_yuv(bob.ip.color.gray_to_rgb(bob.ip.color.rgb_to_gray(normalized_style_image)))

                content_image_yuv = bob.ip.color.rgb_to_yuv(bob.io.base.load(content_image_path))

                output_image = numpy.zeros(shape=content_image_yuv.shape, dtype="uint8")
                output_image[0,:,:] = normalized_style_image_yuv[0,:,:]
                output_image[1,:,:] = content_image_yuv[1,:,:]
                output_image[2,:,:] = content_image_yuv[2,:,:]

                output_image = bob.ip.color.yuv_to_rgb(output_image)
                bob.io.base.save(output_image, output_path)