#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 

import numpy
import tensorflow as tf
import bob.ip.base
import six


class Base(object):
    """
     This class provides the base functionality to shuffle the data used to train a neural network.

     **Parameters**

     data:
       Input data used for training.

     labels:
       Labels. These labels should be integers starting from 0.

     input_shape:
       The shape of the inputs

     input_dtype:
       The data type of the input data.

     batch_size:
       Batch size

     seed:
       The seed of the random number generator

     data_augmentation:
       The algorithm used for data augmentation. See :py:class:`bob.learn.tensorflow.datashuffler.DataAugmentation`.

     normalizer:
       The algorithm used for feature scaling. See :py:class:`bob.learn.tensorflow.datashuffler.ScaleFactor`, :py:class:`bob.learn.tensorflow.datashuffler.Linear` and :py:class:`bob.learn.tensorflow.datashuffler.MeanOffset`.
       
     prefetch:
        If True, prefetch the data through a ``tf.FIFOQueue``.
        
     prefetch_capacity:
        Capacity of the prefetch queue (number of samples).

     prefetch_threads:
        Number of threads used to fill the prefetch queue.
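
     **Example**

     A minimal usage sketch, assuming a concrete subclass such as
     :py:class:`bob.learn.tensorflow.datashuffler.Memory` that provides the
     ``_fetch_batch`` generator::

        shuffler = Memory(data, labels,
                          input_shape=[None, 28, 28, 1],
                          batch_size=16)
        data_ph = shuffler("data", from_queue=False)
        label_ph = shuffler("label", from_queue=False)
        batch_data, batch_labels = shuffler.get_batch()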
    """

    def __init__(self, data, labels,
                 input_shape=[None, 28, 28, 1],
                 input_dtype="float32",
                 batch_size=32,
                 seed=10,
                 data_augmentation=None,
                 normalizer=None,
                 prefetch=False,
                 prefetch_capacity=50,
                 prefetch_threads=5):

        # Setting the seed for the pseudo random number generator
        self.seed = seed
        numpy.random.seed(seed)

        self.normalizer = normalizer
        self.input_dtype = input_dtype

        # TODO: Check if the batch size is larger than the number of input samples
        self.batch_size = batch_size

        # Preparing the inputs
        self.data = data
        self.input_shape = tuple(input_shape)
        self.labels = labels
        self.possible_labels = list(set(self.labels))

        # Computing the number of data samples for training and validation
        self.n_samples = len(self.labels)

        # Shuffling all the indexes
        self.indexes = numpy.array(range(self.n_samples))
        numpy.random.shuffle(self.indexes)

        # Use data augmentation?
        self.data_augmentation = data_augmentation

        # Preparing placeholders
        self.data_ph = None
        self.label_ph = None
        # Prefetch variables
        self.prefetch = prefetch
        self.prefetch_capacity = prefetch_capacity
        self.prefetch_threads = prefetch_threads
        self.data_ph_from_queue = None
        self.label_ph_from_queue = None

        self.batch_generator = None
        self.epoch = 0

    def create_placeholders(self):
        """
        Create the placeholder instances and, if prefetching is enabled, the
        queue that feeds them.
        """
        with tf.name_scope("Input"):

            self.data_ph = tf.placeholder(tf.float32, shape=self.input_shape, name="data")
            self.label_ph = tf.placeholder(tf.int64, shape=[None], name="label")

            # If prefetching, set up the queue that feeds the data
            if self.prefetch:
                queue = tf.FIFOQueue(capacity=self.prefetch_capacity,
                                     dtypes=[tf.float32, tf.int64],
                                     shapes=[self.input_shape[1:], []])

                # Fetching the placeholders from the queue
                self.enqueue_op = queue.enqueue_many([self.data_ph, self.label_ph])
                self.data_ph_from_queue, self.label_ph_from_queue = queue.dequeue_many(self.batch_size)
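                # The enqueue_op is intended to be run by separate feeder
                # threads (e.g. ``sess.run(self.enqueue_op, feed_dict=...)``),
                # while the network consumes ``data_ph_from_queue`` and
                # ``label_ph_from_queue``; starting those threads is left to
                # the trainer.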

            else:
                self.data_ph_from_queue = self.data_ph
                self.label_ph_from_queue = self.label_ph

    def __call__(self, element, from_queue=False):
        """
        Return the requested placeholder ("data" or "label"), either the plain
        placeholder or the one dequeued from the prefetch queue.
        """

        if element not in ["data", "label"]:
            raise ValueError("Value '{0}' invalid. Options available are 'data' and 'label'".format(element))

        # If None, create the placeholders from scratch
        if self.data_ph is None:
            self.create_placeholders()

        if element == "data":
            if from_queue:
                return self.data_ph_from_queue
            else:
                return self.data_ph

        else:
            if from_queue:
                return self.label_ph_from_queue
            else:
                return self.label_ph

    def bob2skimage(self, bob_image):
        """
        Convert a Bob color image (channels, height, width) into a scikit-image style array (height, width, channels).
        """

        skimage = numpy.zeros(shape=(bob_image.shape[1], bob_image.shape[2], bob_image.shape[0]))

        for i in range(bob_image.shape[0]):
            skimage[:, :, i] = bob_image[i, :, :]

        return skimage

    def skimage2bob(self, sk_image):
        """
        Convert a scikit-image style array (height, width, channels) into a Bob color image (channels, height, width).
        """

        bob_image = numpy.zeros(shape=(sk_image.shape[2], sk_image.shape[0], sk_image.shape[1]))

        for i in range(bob_image.shape[0]):
            bob_image[i, :, :] = sk_image[:, :, i]  # Copying each channel

        return bob_image

    def rescale(self, data):
        """
        Rescale a single sample to match ``input_shape``.
        """
        if self.bob_shape != data.shape:

            # TODO: Implement a better way to do this rescaling
            # If it is gray scale
            if self.input_shape[3] == 1:
                copy = data[:, :, 0].copy()
                dst = numpy.zeros(shape=self.input_shape[1:3])
                bob.ip.base.scale(copy, dst)
                dst = numpy.reshape(dst, self.input_shape[1:4])
            else:
                # TODO: this branch only replicates the single (gray) channel
                # into the three color channels; no actual rescaling is done
                dst = numpy.zeros(shape=(data.shape[0], data.shape[1], 3))
                dst[:, :, 0] = data[:, :, 0]
                dst[:, :, 1] = data[:, :, 0]
                dst[:, :, 2] = data[:, :, 0]
            return dst
        else:
            return data

    def normalize_sample(self, x):
        """
        Normalize the sample.

        If a ``normalizer`` was set it is applied to the sample; otherwise the sample is returned unchanged.
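
        The ``normalizer`` is expected to be a callable that receives a single
        sample and returns the scaled sample.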
        """

        if self.normalizer is None:
            return x
        else:
            return self.normalizer(x)

    def _aggregate_batch(self, data_holder, use_list=False):
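        """
        Aggregate a list of samples (each one a sequence of components, e.g.
        ``[data, label]``) into one numpy array per component, ready to be fed
        to the placeholders. If ``use_list`` is True the components are kept
        as plain lists instead.
        """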
        size = len(data_holder[0])
        result = []
        for k in range(size):
            if use_list:
                result.append(
                    [x[k] for x in data_holder])
            else:
                dt = data_holder[0][k]
                if type(dt) in [int, bool]:
                    tp = 'int64'
                elif type(dt) == float:
                    tp = self.input_dtype
                else:
                    try:
                        tp = dt.dtype
                    except AttributeError:
                        raise TypeError("Unsupported type to batch: {}".format(type(dt)))
                try:
                    result.append(
                        numpy.asarray([x[k] for x in data_holder], dtype=tp))
                except KeyboardInterrupt:
                    raise
                except Exception:
                    raise TypeError("Cannot batch data. Perhaps they are of inconsistent shape?")
        return result

    def get_batch(self):
        """
        Shuffle the Memory dataset and get a random batch.

        **Returns**

        data:
          Selected samples

        labels:
          Corresponding labels
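
        ``_fetch_batch`` is expected to be implemented by the subclass as a
        generator that yields one sample (e.g. a ``[data, label]`` pair) at a
        time.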
        """

        if self.batch_generator is None:
            self.batch_generator = self._fetch_batch()

        holder = []
        try:
            for i in range(self.batch_size):
                data = six.next(self.batch_generator)
                holder.append(data)
                if len(holder) == self.batch_size:
                    return self._aggregate_batch(holder, False)

        except StopIteration:
            self.batch_generator = None
            self.epoch += 1
            
            # If there is data left over in this epoch, return it
            if len(holder) > 0:
                return self._aggregate_batch(holder, False)
            else:
                self.batch_generator = self._fetch_batch()
                data = six.next(self.batch_generator)
                holder.append(data)
                return self._aggregate_batch(holder, False)