#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 

import numpy
import tensorflow as tf
import bob.ip.base
import six
from bob.learn.tensorflow.datashuffler.Normalizer import Linear


class Base(object):
    """
     This class provides base functionality to shuffle the data used to train a neural network

     **Parameters**

     data:
       Input data used to train the network

     labels:
       Labels for each sample, given as integers starting from 0

     input_shape:
       The shape of the inputs

     input_dtype:
       The data type of the input data

     batch_size:
       Batch size

     seed:
       The seed of the random number generator

     data_augmentation:
       The algorithm used for data augmentation. See :py:class:`bob.learn.tensorflow.datashuffler.DataAugmentation`

     normalizer:
       The algorithm used for feature scaling. See :py:class:`bob.learn.tensorflow.datashuffler.ScaleFactor`, :py:class:`bob.learn.tensorflow.datashuffler.Linear` and :py:class:`bob.learn.tensorflow.datashuffler.MeanOffset`
       
     prefetch:
        If ``True``, prefetch samples into a ``tf.FIFOQueue``
        
     prefetch_capacity:
        Capacity (number of samples) of the prefetch queue

     prefetch_threads:
        Number of threads used to fill the prefetch queue
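
     **Example**

     A minimal usage sketch (illustrative only: ``Base`` itself does not
     implement ``_fetch_batch``, so a concrete subclass such as ``Memory``
     is assumed here)::

        from bob.learn.tensorflow.datashuffler import Memory

        data_shuffler = Memory(train_data, train_labels,
                               input_shape=[None, 28, 28, 1],
                               batch_size=16)
        batch_data, batch_labels = data_shuffler.get_batch()
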
    """

    def __init__(self, data, labels,
                 input_shape=[None, 28, 28, 1],
                 input_dtype="float32",
                 batch_size=32,
                 seed=10,
                 data_augmentation=None,
                 normalizer=Linear(),
                 prefetch=False,
                 prefetch_capacity=50,
                 prefetch_threads=5):

        # Setting the seed for the pseudo random number generator
        self.seed = seed
        numpy.random.seed(seed)

        self.normalizer = normalizer
        self.input_dtype = input_dtype

        # TODO: Check if the batch size is larger than the number of input samples
        self.batch_size = batch_size

        # Preparing the inputs
        self.data = data
        self.input_shape = tuple(input_shape)
        self.labels = labels
        self.possible_labels = list(set(self.labels))

        # Computing the number of data samples for train and validation
        self.n_samples = len(self.labels)

        # Shuffling all the indexes
        self.indexes = numpy.array(range(self.n_samples))
        numpy.random.shuffle(self.indexes)

        # Use data augmentation?
        self.data_augmentation = data_augmentation

        # Preparing placeholders
        self.data_ph = None
        self.label_ph = None
        # Prefetch variables
        self.prefetch = prefetch
        self.prefetch_capacity = prefetch_capacity
        self.prefetch_threads = prefetch_threads
        self.data_ph_from_queue = None
        self.label_ph_from_queue = None

        self.batch_generator = None
        self.epoch = 0

    def create_placeholders(self):
        """
        Create the placeholder instances and, if prefetching is enabled, the
        queue that feeds them
        """
        with tf.name_scope("Input"):

            self.data_ph = tf.placeholder(tf.float32, shape=self.input_shape, name="data")
            self.label_ph = tf.placeholder(tf.int64, shape=[None], name="label")

            # If prefetch, setup the queue to feed data
            if self.prefetch:
                queue = tf.FIFOQueue(capacity=self.prefetch_capacity,
                                     dtypes=[tf.float32, tf.int64],
                                     shapes=[self.input_shape[1:], []])

                # Setting up the operations to enqueue samples and dequeue batches from the queue
                self.enqueue_op = queue.enqueue_many([self.data_ph, self.label_ph])
                self.data_ph_from_queue, self.label_ph_from_queue = queue.dequeue_many(self.batch_size)
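                # Note: ``self.enqueue_op`` is only defined here; it is expected to be
                # run by a separate feeding thread (e.g. ``session.run(self.enqueue_op,
                # feed_dict=...)``) while training consumes the dequeued batch above.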

            else:
                self.data_ph_from_queue = self.data_ph
                self.label_ph_from_queue = self.label_ph

    def __call__(self, element, from_queue=False):
        """
        Return the requested placeholder ("data" or "label").

        If ``from_queue`` is ``True``, return the placeholder wired to the
        prefetch queue instead of the plain one.
        """

        if element not in ["data", "label"]:
            raise ValueError("Value '{0}' invalid. Options available are 'data' and 'label'".format(element))

        # If None, create the placeholders from scratch
        if self.data_ph is None:
            self.create_placeholders()

        if element == "data":
            if from_queue:
                return self.data_ph_from_queue
            else:
                return self.data_ph

        else:
            if from_queue:
                return self.label_ph_from_queue
            else:
                return self.label_ph


    def bob2skimage(self, bob_image):
        """
        Convert a Bob color image (channels, height, width) to a scikit-image
        style array (height, width, channels)
        """

        skimage = numpy.zeros(shape=(bob_image.shape[1], bob_image.shape[2], bob_image.shape[0]))

        for i in range(bob_image.shape[0]):
            skimage[:, :, i] = bob_image[i, :, :]

        return skimage

    def skimage2bob(self, sk_image):
        """
        Convert a scikit-image style array (height, width, channels) to a Bob
        color image (channels, height, width)
        """

        bob_image = numpy.zeros(shape=(sk_image.shape[2], sk_image.shape[0], sk_image.shape[1]))

        for i in range(bob_image.shape[0]):
            bob_image[i, :, :] = sk_image[:, :, i]  # Copy each channel

        return bob_image

    def rescale(self, data):
        """
        Rescale a single sample to the target input shape

        """
        if self.bob_shape != data.shape:

            # TODO: Implement a better way to do this rescaling
            # If it is gray scale
            if self.input_shape[3] == 1:
                copy = data[:, :, 0].copy()
                dst = numpy.zeros(shape=self.input_shape[1:3])
                bob.ip.base.scale(copy, dst)
                dst = numpy.reshape(dst, self.input_shape[1:4])
            else:
                # Color image: scale directly into a destination array with Bob's
                # (channels, height, width) layout
                # TODO: handle gray-scale samples stored in a color (RGB) database
                dst = numpy.zeros(shape=(self.bob_shape))
                bob.ip.base.scale(data, dst)

            return dst
        else:
            return data

    def normalize_sample(self, x):
        """
        Normalize the sample.

        For the time being I'm only scaling from 0-1
        """

        return self.normalizer(x)

    def _aggregate_batch(self, data_holder, use_list=False):
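        """
        Stack a list of samples into one :py:class:`numpy.ndarray` per element
        (e.g. one array for the data and one for the labels), inferring the
        dtype of each element from the first sample.
        """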
        size = len(data_holder[0])
        result = []
        for k in range(size):
            if use_list:
                result.append(
                    [x[k] for x in data_holder])
            else:
                dt = data_holder[0][k]
                if type(dt) in [int, bool]:
                    tp = 'int64'
                elif type(dt) == float:
                    tp = self.input_dtype
                else:
                    try:
                        tp = dt.dtype
                    except AttributeError:
                        raise TypeError("Unsupported type to batch: {}".format(type(dt)))
                try:
                    result.append(
                        numpy.asarray([x[k] for x in data_holder], dtype=tp))
                except KeyboardInterrupt:
                    raise
                except Exception:
                    raise ValueError(
                        "Cannot batch data at element {0}. Perhaps the samples "
                        "have inconsistent shapes?".format(k))
        return result

    def get_batch(self):
        """
        Shuffle the Memory dataset and get a random batch.

        **Returns**

        data:
          Selected samples

        labels:
          Corresponding labels
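
        Note that ``self.epoch`` is incremented each time the underlying batch
        generator is exhausted, so callers can track completed passes over the
        data.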
        """

        if self.batch_generator is None:
            self.batch_generator = self._fetch_batch()

        holder = []
        try:
            for i in range(self.batch_size):
                data = six.next(self.batch_generator)
                
                holder.append(data)
                if len(holder) == self.batch_size:
                    return self._aggregate_batch(holder, False)

        except StopIteration:
            self.batch_generator = None
            self.epoch += 1
            
            # If there is data left over in this epoch, return it
            if len(holder) > 0:
                return self._aggregate_batch(holder, False)
            else:
                self.batch_generator = self._fetch_batch()
                data = six.next(self.batch_generator)
                holder.append(data)
                return self._aggregate_batch(holder, False)