#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>


import tensorflow as tf
import numpy
import os


def compute_features(input_image, architecture, checkpoint_dir, target_end_points):
    """
    For a given set of end points, convolve the input image until these points

    Parameters
    ----------

    input_image: numpy.array
        Input image in the format 1xWxHxC. The placeholder it is fed through
        is built with a fixed batch size of one and takes the remaining
        dimensions from ``input_image.shape[1:4]``.

    architecture:
        Pointer to the architecture function. It is called as
        ``architecture(inputs, mode=tf.estimator.ModeKeys.PREDICT,
        trainable_variables=None)`` and has to return a pair whose second
        element is the dictionary of end point tensors.

    checkpoint_dir: str
        DCNN checkpoint directory. If it is not a directory it is handed to
        the saver as it is (i.e. as the path of one specific checkpoint).

    target_end_points: list
        Keys (in the end point dictionary returned by ``architecture``) of the
        end points that have to be computed

    Returns
    -------

    list
        One numpy.array per requested end point, in the same order as
        ``target_end_points``

    Raises
    ------

    ValueError
        If ``checkpoint_dir`` is a directory where no checkpoint can be found

    """

    try:
        input_pl = tf.placeholder('float32', shape=(1, input_image.shape[1],
                                                       input_image.shape[2],
                                                       input_image.shape[3]))

        # TODO: Think on how abstract this normalization operation
        normalized = tf.stack([tf.image.per_image_standardization(i)
                               for i in tf.unstack(input_pl)])
        _, end_points = architecture(normalized,
                                     mode=tf.estimator.ModeKeys.PREDICT,
                                     trainable_variables=None)

        # Resolving the requested tensors up front, so an unknown end point
        # fails before a session is opened and a checkpoint is loaded
        fetches = [end_points[ep] for ep in target_end_points]

        with tf.Session() as sess:
            # Restoring the checkpoint for the given architecture
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()

            if os.path.isdir(checkpoint_dir):
                checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
                if checkpoint_path is None:
                    raise ValueError(
                        "No checkpoint found in the directory `{0}`".format(checkpoint_dir))
            else:
                checkpoint_path = checkpoint_dir
            saver.restore(sess, checkpoint_path)

            # One single run for all the end points: the network is evaluated
            # once instead of once per requested end point
            if fetches:
                features = sess.run(fetches, feed_dict={input_pl: input_image})
            else:
                features = []
    finally:
        # Killing the graph, also when something went wrong, so the next call
        # does not build on top of a half constructed graph
        tf.reset_default_graph()

    return features


def compute_gram(features):
    """
    Compute the gram matrix of each feature map in a list, pinning the
    channel axis, as in:

    Gatys, Leon A., Alexander S. Ecker, and Matthias Bethge. "A neural algorithm of artistic style." arXiv preprint arXiv:1508.06576 (2015).

    Parameters
    ----------

    features: list
      Convolved features (numpy.array), each one in the format NxWxHxC

    Returns
    -------

    list
      One CxC gram matrix (numpy.array) per input feature map, each one
      divided by the number of elements of its feature map

    """

    def _gram(feature_map):
        # Collapsing every axis but the channel one: (N*W*H) x C
        n_channels = feature_map.shape[3]
        flat = numpy.reshape(feature_map, (-1, n_channels))
        # Channel by channel correlations, normalized by the element count
        return numpy.matmul(flat.T, flat) / flat.size

    return [_gram(feature_map) for feature_map in features]