style_transfer.py 6.27 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#!/usr/bin/env python
"""Trains networks using Tensorflow estimators.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import click
import tensorflow as tf
from bob.extension.scripts.click_helper import (verbosity_option,
                                                ConfigCommand, ResourceOption)
import bob.io.image
import bob.io.base
import numpy
import bob.ip.base
import bob.ip.color
import sys

from bob.learn.tensorflow.style_transfer import compute_features, compute_gram
from bob.learn.tensorflow.loss import linear_gram_style_loss, content_loss, denoising_loss


logger = logging.getLogger(__name__)

def normalize4save(img):
    """Min-max normalize an image into [0, 255] and cast to uint8.

    Parameters
    ----------
    img : numpy.ndarray
        Array of arbitrary shape and numeric dtype.

    Returns
    -------
    numpy.ndarray
        Same shape as ``img``, dtype ``uint8``, values scaled so that the
        minimum maps to 0 and the maximum to 255.

    Notes
    -----
    A constant image (``max == min``) would cause a 0/0 division in the
    naive formula; it is mapped to all zeros instead.
    """
    lo = numpy.min(img)
    hi = numpy.max(img)
    span = hi - lo
    if span == 0:
        # Zero dynamic range: any scaling is arbitrary; return black.
        return numpy.zeros_like(img, dtype="uint8")
    return (255 * ((img - lo) / span)).astype("uint8")

@click.command(
    entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
@click.argument('content_image', required=True)
@click.argument('style_image', required=True)
@click.argument('output', required=True)
@click.option('--architecture',
              '-a',
              required=True,
              cls=ResourceOption,
              entry_point_group='bob.learn.tensorflow.architecture',
              help='The base architecture.')
@click.option('--checkpoint-dir',
              '-c',
              required=True,
              cls=ResourceOption,
              help='Directory containing the checkpoint of the base architecture.')
@click.option('--iterations',
              '-i',
              type=click.types.INT,
              help='Number of steps for which to train model.',
              default=1000)
# NOTE: the short flag here used to be '-i', clashing with --iterations;
# '-r' resolves the ambiguity (the long --learning_rate form is unchanged).
@click.option('--learning_rate',
              '-r',
              type=click.types.FLOAT,
              help='Learning rate.',
              default=1.)
@click.option('--content-weight',
              type=click.types.FLOAT,
              help='Weight of the content loss.',
              default=1.)
@click.option('--style-weight',
              type=click.types.FLOAT,
              help='Weight of the style loss.',
              default=1000.)
@click.option('--denoise-weight',
              type=click.types.FLOAT,
              help='Weight denoising loss.',
              default=100.)
@click.option('--content-end-points',
              cls=ResourceOption,
              multiple=True,
              entry_point_group='bob.learn.tensorflow.end_points',
              help='List of end_points used to encode the content')
@click.option('--style-end-points',
              cls=ResourceOption,
              multiple=True,
              entry_point_group='bob.learn.tensorflow.end_points',
              help='List of end_points used to encode the style')
@click.option('--scopes',
              cls=ResourceOption,
              entry_point_group='bob.learn.tensorflow.scopes',
              help='Dictionary containing the mapping scopes',
              required=True)
@verbosity_option(cls=ResourceOption)
def style_transfer(content_image, style_image, output,
                   architecture, checkpoint_dir,
                   iterations, learning_rate,
                   content_weight, style_weight, denoise_weight, content_end_points,
                   style_end_points, scopes,  **kwargs):
    """
    Trains neural style transfer from

    Gatys, Leon A., Alexander S. Ecker, and Matthias Bethge. "A neural algorithm of artistic style." arXiv preprint arXiv:1508.06576 (2015).

    The image CONTENT_IMAGE is re-rendered in the style of STYLE_IMAGE by
    optimizing a noise image so that its content features match those of
    CONTENT_IMAGE and its Gram matrices match those of STYLE_IMAGE, then
    the result is written to OUTPUT.
    """

    # Load images and convert from bob's (C, H, W) layout to the
    # tensorflow-friendly (H, W, C) layout.
    content_image = bob.io.image.to_matplotlib(bob.io.base.load(content_image))
    style_image = bob.io.image.to_matplotlib(bob.io.base.load(style_image))

    # Reshaping to NxWxHxC (batch dimension of 1)
    content_image = numpy.reshape(content_image, (1, content_image.shape[0],
                                                     content_image.shape[1],
                                                     content_image.shape[2]))

    style_image = numpy.reshape(style_image, (1, style_image.shape[0],
                                                 style_image.shape[1],
                                                 style_image.shape[2]))

    # Content targets: activations of the chosen end-points on the content image
    content_features = compute_features(content_image, architecture, checkpoint_dir, content_end_points)

    # Style targets: Gram matrices of the chosen end-points on the style image
    # TODO: Enable a set of style images
    style_features = compute_features(style_image, architecture, checkpoint_dir, style_end_points)
    style_grams = compute_gram(style_features)

    # Build the optimization graph: the only trainable variable is the image itself
    with tf.Graph().as_default():
        # Start from random noise with the same shape as the content image
        noise = tf.Variable(tf.random_normal(shape=content_image.shape),
                            trainable=True)

        # Run the (frozen) network on the noise image; trainable_variables=[]
        # keeps the network weights fixed so only `noise` is optimized.
        _, end_points = architecture(noise,
                                      mode=tf.estimator.ModeKeys.PREDICT,
                                      trainable_variables=[])

        # Content loss: distance between noise activations and content targets
        content_noises = []
        for c in content_end_points:
            content_noises.append(end_points[c])
        c_loss = content_loss(content_noises, content_features)

        # Style loss: distance between the noise image's Gram matrices and
        # the style targets, one Gram matrix per style end-point.
        style_gram_noises = []
        for c in style_end_points:
            layer = end_points[c]
            _, height, width, number = map(lambda i: i.value, layer.get_shape())
            size = height * width * number
            features = tf.reshape(layer, (-1, number))
            style_gram_noises.append(tf.matmul(tf.transpose(features), features) / size)
        s_loss = linear_gram_style_loss(style_gram_noises, style_grams)

        # Total-variation denoising term keeps the result spatially smooth
        d_loss = denoising_loss(noise)

        # Total loss: weighted sum of the three terms
        total_loss = content_weight*c_loss + style_weight*s_loss + denoise_weight*d_loss

        solver = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)

        # Restore the pretrained network weights (mapped through `scopes`)
        # before optimization starts.
        tf.contrib.framework.init_from_checkpoint(tf.train.latest_checkpoint(checkpoint_dir),
                                                  scopes)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            for i in range(iterations):
                _, loss = sess.run([solver, total_loss])
                print("Iteration {0}, loss {1}".format(i, loss))
                sys.stdout.flush()

            # Fetch the optimized image, convert back to bob layout and save
            style_image = sess.run(noise)[0, :, :,:]
            style_image = bob.io.image.to_bob(style_image)
            bob.io.base.save(normalize4save(style_image), output)