# image.py
import numpy as np
import tensorflow as tf


def to_channels_last(image):
    """Converts the image to channel_last format. This is the same format as in
    matplotlib, skimage, and etc.

    Parameters
    ----------
    image : `tf.Tensor`
        At least a 3 dimensional image. If the dimension is more than 3, the
        last 3 dimensions are assumed to be [C, H, W].

    Returns
    -------
    image : `tf.Tensor`
        The image in [..., H, W, C] format.

    Raises
    ------
    ValueError
        If dim of image is less than 3.
    """
    ndim = len(image.shape)
    if ndim < 3:
        raise ValueError(
            "The image needs to be at least 3 dimensional but it was {}".format(ndim)
        )
    # Any leading (batch-like) axes stay where they are; only the trailing
    # three axes are permuted, from [C, H, W] to [H, W, C].
    shift = ndim - 3
    axis_order = list(range(shift)) + [n + shift for n in (1, 2, 0)]
    return tf.transpose(a=image, perm=axis_order)


def to_channels_first(image):
    """Converts the image to channel_first format. This is the same format as
    in bob.io.image and bob.io.video.

    Parameters
    ----------
    image : `tf.Tensor`
        At least a 3 dimensional image. If the dimension is more than 3, the
        last 3 dimensions are assumed to be [H, W, C].

    Returns
    -------
    image : `tf.Tensor`
        The image in [..., C, H, W] format.

    Raises
    ------
    ValueError
        If dim of image is less than 3.
    """
    ndim = len(image.shape)
    if ndim < 3:
        raise ValueError(
            "The image needs to be at least 3 dimensional but it was {}".format(ndim)
        )
    # Any leading (batch-like) axes stay where they are; only the trailing
    # three axes are permuted, from [H, W, C] to [C, H, W].
    shift = ndim - 3
    axis_order = list(range(shift)) + [n + shift for n in (2, 0, 1)]
    return tf.transpose(a=image, perm=axis_order)


def blocks_tensorflow(images, block_size):
    """Cut a batch of images into non-overlapping blocks with tensorflow ops.

    Parameters
    ----------
    images : `tf.Tensor`
        The input color images, assumed to have a shape of [?, H, W, C].
    block_size : (int, int)
        The height and width of one block.

    Returns
    -------
    blocks : `tf.Tensor`
        Every block of every image, stacked along the batch dimension. The
        shape is [?, block_size[0], block_size[1], C].
    n_blocks : int
        How many blocks were taken from each image.
    """
    # extract_patches expects 4-element sizes/strides: [1, rows, cols, 1].
    # Using the same list for both makes the patches non-overlapping.
    patch_spec = [1, *block_size, 1]
    # Final layout: unknown batch, block rows, block cols, input channels.
    target_shape = [-1, *patch_spec[1:-1], images.shape[-1]]
    patches = tf.image.extract_patches(
        images, patch_spec, patch_spec, [1, 1, 1, 1], "VALID"
    )
    # Axes 1 and 2 of the patches tensor count the block rows and columns.
    per_image = int(np.prod(patches.shape[1:3]))
    return tf.reshape(patches, target_shape), per_image


def tf_repeat(tensor, repeats):
    """Tile and reshape a tensor so every dimension grows by its factor in
    ``repeats``.

    Parameters
    ----------
    tensor
        A Tensor. 1-D or higher.
    repeats
        A list. The repeat factor for each dimension; its length must equal
        the number of dimensions of ``tensor``.

    Returns
    -------
    A Tensor of the same type as the input, with shape
    ``tensor.shape * repeats`` (element-wise).
    """
    with tf.name_scope("repeat"):
        # Tile along an extra trailing axis, then fold the copies back into
        # the requested output shape.
        tiled = tf.tile(tf.expand_dims(tensor, -1), multiples=[1] + repeats)
        target_shape = tf.shape(input=tensor) * repeats
        repeated = tf.reshape(tiled, target_shape)
    return repeated


def all_patches(image, label, key, size):
    """Extracts every non-overlapping patch of an image and repeats its label
    and key to match.

    Parameters
    ----------
    image:
        The image should be channels_last format and already batched.

    label:
        The label for the image

    key:
        The key for the image

    size: (int, int)
        The height and width of the blocks.

    Returns
    -------
    blocks:
       The non-overlapping blocks of ``size`` taken from ``image``.

    label:
       ``label`` repeated ``n_blocks`` times along its first dimension.

    key:
       ``key`` repeated ``n_blocks`` times along its first dimension.
    """
    blocks, n_blocks = blocks_tensorflow(image, size)

    def _multiples(shape):
        # n_blocks copies along the first axis, every other axis unchanged
        return [n_blocks if axis == 0 else 1 for axis, _ in enumerate(shape)]

    label = tf_repeat(label, _multiples(label.shape))
    key = tf_repeat(key, _multiples(key.shape))

    return blocks, label, key