#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST

import numpy
import tensorflow as tf
import bob.ip.base
9
import numpy
10
from bob.learn.tensorflow.datashuffler.Normalizer import Linear
11
12
13


class Base(object):
    """
     Base functionality to shuffle the data used to train a neural network.

     The constructor stores the data and labels, seeds the global numpy random
     number generator and builds a shuffled array of sample indexes.
     Derived classes are expected to implement :py:meth:`get_batch`.

     **Parameters**

     data:
       Input data used for training

     labels:
       Labels. These labels should be set from 0..1

     input_shape:
       The shape of the inputs (stored as a tuple)

     input_dtype:
       The type of the data. It is only stored here; the placeholders created
       by :py:meth:`create_placeholders` are always ``tf.float32``.

     batch_size:
       Batch size

     seed:
       The seed of the random number generator

     data_augmentation:
       The algorithm used for data augmentation. Look :py:class:`bob.learn.tensorflow.datashuffler.DataAugmentation`

     normalizer:
       The algorithm used for feature scaling. Look :py:class:`bob.learn.tensorflow.datashuffler.ScaleFactor`, :py:class:`bob.learn.tensorflow.datashuffler.Linear` and :py:class:`bob.learn.tensorflow.datashuffler.MeanOffset`.
       If `None`, a new :py:class:`bob.learn.tensorflow.datashuffler.Linear` is created for this instance.

     prefetch:
        Do prefetch? If true, :py:meth:`create_placeholders` feeds the
        placeholders through a ``tf.FIFOQueue``.

     prefetch_capacity:
        Capacity of the ``tf.FIFOQueue`` created when `prefetch` is enabled

    """

    # Names accepted by ``__call__``; quoted in its error message.
    placeholder_options = ("data", "label")

    def __init__(self, data, labels,
                 input_shape=(None, 28, 28, 1),
                 input_dtype="float64",
                 batch_size=32,
                 seed=10,
                 data_augmentation=None,
                 normalizer=None,
                 prefetch=False,
                 prefetch_capacity=10):

        # Setting the seed for the pseudo random number generator.
        # Note that this seeds the *global* numpy generator.
        self.seed = seed
        numpy.random.seed(seed)

        # A default normalizer is built per instance instead of being shared
        # through a default argument evaluated once at definition time.
        if normalizer is None:
            normalizer = Linear()
        self.normalizer = normalizer

        self.input_dtype = input_dtype

        # TODO: Check if the batch size is higher than the input data
        self.batch_size = batch_size

        # Preparing the inputs
        self.data = data
        self.input_shape = tuple(input_shape)
        self.labels = labels
        self.possible_labels = list(set(self.labels))

        # Computing the number of data samples
        self.n_samples = len(self.labels)

        # Shuffling all the indexes
        self.indexes = numpy.arange(self.n_samples)
        numpy.random.shuffle(self.indexes)

        # Use data augmentation?
        self.data_augmentation = data_augmentation

        # Placeholders are created lazily by ``create_placeholders``
        self.data_ph = None
        self.label_ph = None

        # Prefetch variables
        self.prefetch = prefetch
        self.prefetch_capacity = prefetch_capacity
        self.enqueue_op = None
        self.data_ph_from_queue = None
        self.label_ph_from_queue = None

    def create_placeholders(self):
        """
        Create the place holder instances.

        Sets ``data_ph`` and ``label_ph``. With prefetch enabled, a FIFO queue
        is also created: ``enqueue_op`` feeds it from the placeholders and the
        ``*_from_queue`` attributes hold a dequeue of ``batch_size`` elements.
        Without prefetch, the ``*_from_queue`` attributes are the placeholders
        themselves.

        :return: None
        """
        with tf.name_scope("Input"):

            self.data_ph = tf.placeholder(tf.float32, shape=self.input_shape, name="data")
            self.label_ph = tf.placeholder(tf.int64, shape=[None], name="label")

            # If prefetch, setup the queue to feed data
            if self.prefetch:
                # Queue elements are single samples, hence the first
                # entry of ``input_shape`` is dropped
                queue = tf.FIFOQueue(capacity=self.prefetch_capacity,
                                     dtypes=[tf.float32, tf.int64],
                                     shapes=[self.input_shape[1:], []])

                # Fetching the place holders from the queue
                self.enqueue_op = queue.enqueue_many([self.data_ph, self.label_ph])
                self.data_ph_from_queue, self.label_ph_from_queue = queue.dequeue_many(self.batch_size)

            else:
                self.data_ph_from_queue = self.data_ph
                self.label_ph_from_queue = self.label_ph

    def __call__(self, element, from_queue=False):
        """
        Return the necessary placeholder.

        **Parameters**

        element:
          Either "data" or "label"

        from_queue:
          If true, return the ``*_from_queue`` variant of the placeholder

        Raises a ``ValueError`` if `element` is not one of the options above.
        """

        if element not in ("data", "label"):
            raise ValueError("Value '{0}' invalid. Options available are {1}".format(element, self.placeholder_options))

        # If None, create the placeholders from scratch
        if self.data_ph is None:
            self.create_placeholders()

        if element == "data":
            return self.data_ph_from_queue if from_queue else self.data_ph

        return self.label_ph_from_queue if from_queue else self.label_ph

    def get_batch(self):
        """
        Shuffle dataset and get a random batch.

        Not implemented at this level; derived classes must provide it.
        """
        raise NotImplementedError("Method not implemented in this level. You should use one of the derived classes.")

    def bob2skimage(self, bob_image):
        """
        Convert a bob color image (axes 0, 1, 2) to the scikit-image layout
        (axes 1, 2, 0).

        Returns a new float64 array; the input is not modified.
        """

        # numpy.array copies, so the result never aliases the input
        return numpy.array(numpy.transpose(bob_image, (1, 2, 0)), dtype=numpy.float64, order="C")

    def skimage2bob(self, sk_image):
        """
        Convert a scikit-image color image (axes 0, 1, 2) to the bob layout
        (axes 2, 0, 1).

        Returns a new float64 array; the input is not modified.
        """

        # numpy.array copies, so the result never aliases the input
        return numpy.array(numpy.transpose(sk_image, (2, 0, 1)), dtype=numpy.float64, order="C")

    def rescale(self, data):
        """
        Rescale a single sample to the expected shape.

        The sample is returned untouched when its shape already equals
        ``self.bob_shape``.

        NOTE(review): ``bob_shape`` is not set anywhere in this class, so it
        has to be provided by a derived class before calling this method.
        NOTE(review): the indexing below reads ``input_shape`` as
        (height, width, channels), while the constructor default carries one
        more leading entry - confirm against the callers.
        """
        if self.bob_shape == data.shape:
            return data

        # TODO: Implement a better way to do this rescaling
        # If it is gray scale
        if self.input_shape[2] == 1:
            plane = data[:, :, 0].copy()
            dst = numpy.zeros(shape=self.input_shape[0:2])
            bob.ip.base.scale(plane, dst)
            return numpy.reshape(dst, self.input_shape)

        dst = numpy.zeros(shape=self.bob_shape)

        # TODO: LAME SOLUTION
        if data.shape[0] != 3:  # GRAY SCALE IMAGES IN A RGB DATABASE
            # Replicate the single plane over three channels
            step_data = numpy.zeros(shape=(3, data.shape[0], data.shape[1]))
            for channel in range(3):
                step_data[channel, ...] = data[:, :]
            data = step_data

        bob.ip.base.scale(data, dst)
        return dst

    def normalize_sample(self, x):
        """
        Normalize the sample by applying ``self.normalizer`` to it.
        """

        return self.normalizer(x)