# NOTE(review): removed web-viewer extraction artifacts (filename banner,
# "Newer Older" header, and stray gutter line numbers) that were not code.
#!/usr/bin/env python
"""Trains networks using Tensorflow estimators.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import click
import tensorflow as tf
from bob.extension.scripts.click_helper import (verbosity_option,
                                                ConfigCommand, ResourceOption)
import bob.io.image
import bob.io.base
import numpy
import bob.ip.base
import bob.ip.color
import sys

from bob.learn.tensorflow.style_transfer import compute_features, compute_gram
from bob.learn.tensorflow.loss import linear_gram_style_loss, content_loss, denoising_loss


logger = logging.getLogger(__name__)

def normalize4save(img):
    """Min-max normalize an image to [0, 255] and cast it to ``uint8``.

    Parameters
    ----------
    img : numpy.ndarray
        Array with an arbitrary value range (e.g. the raw generated noise).

    Returns
    -------
    numpy.ndarray
        ``uint8`` array where the input minimum maps to 0 and the input
        maximum maps to 255.  A constant input maps to all zeros instead
        of producing ``0/0`` NaNs (which would make the uint8 cast
        undefined).
    """
    lo = numpy.min(img)
    hi = numpy.max(img)
    span = hi - lo
    if span == 0:
        # BUG FIX: guard against division by zero on constant images.
        return numpy.zeros_like(img, dtype="uint8")
    return (255 * ((img - lo) / span)).astype("uint8")

#@click.argument('style_image_path', required=True)


@click.command(
    entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
@click.argument('content_image_path', required=True)
@click.argument('output_path', required=True)
@click.option('--style-image-paths',
              cls=ResourceOption,
              required=True,
              multiple=True,
              entry_point_group='bob.learn.tensorflow.style_images',
              help='List of images that encode the style.')
@click.option('--architecture',
              '-a',
              required=True,
              cls=ResourceOption,
              entry_point_group='bob.learn.tensorflow.architecture',
              help='The base architecture.')
@click.option('--checkpoint-dir',
              '-c',
              required=True,
              cls=ResourceOption,
              help='Directory containing the checkpoint of the base architecture.')
@click.option('--iterations',
              '-i',
              type=click.types.INT,
              help='Number of steps for which to train model.',
              default=1000)
@click.option('--learning_rate',
              '-r',
              type=click.types.FLOAT,
              help='Learning rate.',
              default=1.)
@click.option('--content-weight',
              type=click.types.FLOAT,
              help='Weight of the content loss.',
              default=1.)
@click.option('--style-weight',
              type=click.types.FLOAT,
              help='Weight of the style loss.',
              default=1000.)
@click.option('--denoise-weight',
              type=click.types.FLOAT,
              help='Weight denoising loss.',
              default=100.)
@click.option('--content-end-points',
              cls=ResourceOption,
              multiple=True,
              entry_point_group='bob.learn.tensorflow.end_points',
              help='List of end_points used to encode the content')
@click.option('--style-end-points',
              cls=ResourceOption,
              multiple=True,
              entry_point_group='bob.learn.tensorflow.end_points',
              help='List of end_points used to encode the style')
@click.option('--scopes',
              cls=ResourceOption,
              entry_point_group='bob.learn.tensorflow.scopes',
              help='Dictionary containing the mapping scores',
              required=True)
@click.option('--pure-noise',
               is_flag=True,
               help="If set will save the raw noisy generated image."
                    "If not set, the output will be RGB = stylizedYUV.Y, originalYUV.U, originalYUV.V"
              )
@verbosity_option(cls=ResourceOption)
def style_transfer(content_image_path, output_path, style_image_paths,
                   architecture, checkpoint_dir,
                   iterations, learning_rate,
                   content_weight, style_weight, denoise_weight, content_end_points,
                   style_end_points, scopes, pure_noise,  **kwargs):
    """
     Trains neural style transfer from

    Gatys, Leon A., Alexander S. Ecker, and Matthias Bethge. "A neural algorithm of artistic style." arXiv preprint arXiv:1508.06576 (2015).

    The generated image is optimized directly (the network weights are
    frozen from `checkpoint_dir`); the result is written to `output_path`.
    """
    # BUG FIX (interface): `--learning_rate` previously declared the short
    # flag `-i`, which clashed with `--iterations`; it is now `-r`.

    # Read the content image; bob loads CxHxW, convert to HxWxC.
    content_image = bob.io.image.to_matplotlib(bob.io.base.load(content_image_path))
    # Read each style image and reshape to NxHxWxC with N=1.
    style_images = []
    for path in style_image_paths:
        image = bob.io.image.to_matplotlib(bob.io.base.load(path))
        style_images.append(numpy.reshape(image, (1, image.shape[0],
                                                     image.shape[1],
                                                     image.shape[2])))

    # Reshaping to NxWxHxC
    content_image = numpy.reshape(content_image, (1, content_image.shape[0],
                                                     content_image.shape[1],
                                                     content_image.shape[2]))

    # Base content features
    content_features = compute_features(content_image, architecture, checkpoint_dir, content_end_points)

    # Base style features: one list of gram matrices per style image.
    style_grams = []
    for image in style_images:
        style_features = compute_features(image, architecture, checkpoint_dir, style_end_points)
        style_grams.append(compute_gram(style_features))

    # Organizing the trainer
    with tf.Graph().as_default():
        tf.set_random_seed(0)

        # The generated image starts as random noise and is the only
        # trainable variable in the graph.
        noise = tf.Variable(tf.random_normal(shape=content_image.shape),
                            trainable=True)

        _, end_points = architecture(noise,
                                      mode=tf.estimator.ModeKeys.PREDICT,
                                      trainable_variables=[])

        # Computing content loss
        content_noises = []
        for c in content_end_points:
            content_noises.append(end_points[c])
        c_loss = content_loss(content_noises, content_features)

        # Gram matrices of the generated (noise) image.  They depend only
        # on the noise end points, not on any style image, so they are
        # computed once here.
        # BUG FIX: this list used to be filled *inside* the loop over
        # style images without being reset, so from the second style image
        # on, the two gram lists handed to linear_gram_style_loss had
        # mismatched lengths.
        style_gram_noises = []
        for c in style_end_points:
            layer = end_points[c]
            _, height, width, number = map(lambda i: i.value, layer.get_shape())
            size = height * width * number
            features = tf.reshape(layer, (-1, number))
            style_gram_noises.append(tf.matmul(tf.transpose(features), features) / size)

        # Computing style_loss, summed over all style images.
        s_loss = 0
        for grams_per_image in style_grams:
            s_loss += linear_gram_style_loss(style_gram_noises, grams_per_image)

        # Variation denoise
        d_loss = denoising_loss(noise)

        #Total loss
        total_loss = content_weight*c_loss + style_weight*s_loss + denoise_weight*d_loss

        solver = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)

        # Load the pre-trained base-network weights; `scopes` maps
        # checkpoint scopes to graph scopes.
        tf.contrib.framework.init_from_checkpoint(tf.train.latest_checkpoint(checkpoint_dir),
                                                  scopes)
        # Training
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            for i in range(iterations):
                _, loss = sess.run([solver, total_loss])
                print("Iteration {0}, loss {1}".format(i, loss))
                sys.stdout.flush()

            # Saving generated image
            raw_style_image = sess.run(noise)[0, :, :,:]
            raw_style_image = bob.io.image.to_bob(raw_style_image)
            normalized_style_image = normalize4save(raw_style_image)

            if pure_noise:
                bob.io.base.save(normalized_style_image, output_path)
            else:
                # Keep only the luminance (Y) of the stylized image and
                # take the chrominance (U, V) from the original content,
                # so the output preserves the content image's colors.
                normalized_style_image_yuv = bob.ip.color.rgb_to_yuv(bob.ip.color.gray_to_rgb(bob.ip.color.rgb_to_gray(normalized_style_image)))

                content_image_yuv = bob.ip.color.rgb_to_yuv(bob.io.base.load(content_image_path))

                output_image = numpy.zeros(shape=content_image_yuv.shape, dtype="uint8")
                output_image[0,:,:] = normalized_style_image_yuv[0,:,:]
                output_image[1,:,:] = content_image_yuv[1,:,:]
                output_image[2,:,:] = content_image_yuv[2,:,:]

                output_image = bob.ip.color.yuv_to_rgb(output_image)
                bob.io.base.save(output_image, output_path)