TripletDisk.py 6.27 KB
Newer Older
1 2 3 4 5 6 7 8
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 

import numpy
import bob.io.base
import bob.io.image
9
import bob.ip.base
10 11 12
import bob.core
# Module-level logger obtained from bob's logging setup; used below to warn
# when a loaded image contains NaN values.
logger = bob.core.log.setup("bob.learn.tensorflow")

13 14 15 16 17 18 19 20 21 22 23 24
import tensorflow as tf

from .BaseDataShuffler import BaseDataShuffler

#def scale_mean_norm(data, scale=0.00390625):
#    mean = numpy.mean(data)
#    data = (data - mean) * scale

#    return data, mean


class TextDataShuffler(BaseDataShuffler):
25 26 27 28
    def __init__(self, data, labels,
                 input_shape,
                 input_dtype="float64",
                 scale=True,
                 batch_size=1):
        """
         Shuffler that deals with a list of files.

         **Parameters**
         data: File list (presumably file names, since the samples are later
               loaded from disk one by one). A python list is converted to a
               `numpy.ndarray`
         labels: Labels of the samples. A python list is converted to a
                 `numpy.ndarray`
         input_shape: Shape of the input, indexed here as
                      (dim 0, dim 1, channels). If `input_shape != data.shape`,
                      the data will be reshaped
         input_dtype="float64": Forwarded unchanged to `BaseDataShuffler`
         scale=True: Forwarded unchanged to `BaseDataShuffler`
         batch_size=1: Forwarded unchanged to `BaseDataShuffler`
        """

        # Plain python lists are handed to the base class as numpy arrays
        data = numpy.array(data) if isinstance(data, list) else data
        labels = numpy.array(labels) if isinstance(labels, list) else labels

        super(TextDataShuffler, self).__init__(
            data=data,
            labels=labels,
            input_shape=input_shape,
            input_dtype=input_dtype,
            scale=scale,
            batch_size=batch_size
        )

        # TODO: very bad solution to deal with bob.shape images an tf shape images
        # Same dimensions as input_shape, with the channel entry moved to the front
        self.bob_shape = (input_shape[2],) + tuple(input_shape[0:2])

60 61
    def load_from_file(self, file_name, shape):
        """
        Load one image from disk and return it with the channels as last axis.

        **Parameters**
            file_name: Path handed to `bob.io.base.load`
            shape: Not used by this method

        **Return**
            The loaded sample after going through `self.rescale`. A warning is
            logged when the result contains NaN values.
        """
        loaded = bob.io.base.load(file_name)

        if loaded.shape[0] == 3 or self.input_shape[2] == 3:
            # Channel-first path: rescale, then move the channels to the last axis
            data = self.bob2skimage(self.rescale(loaded))
        else:
            # GRAY SCALE IMAGE: add a trailing single-channel axis, then rescale
            data = numpy.zeros(shape=(loaded.shape[0], loaded.shape[1], 1))
            data[:, :, 0] = loaded
            data = self.rescale(data)

        # Checking NaN
        if numpy.isnan(data).any():
            logger.warning("######### Image {0} has noise #########".format(file_name))

        return data
75

76 77 78 79 80 81 82
    def bob2skimage(self, bob_image):
        """
        Convert a bob color image (channels on the first axis) to the
        scikit-image layout (channels on the last axis).

        **Parameters**
            bob_image: 3D array whose first axis holds the red, green and blue planes

        **Return**
            A new float64 array shaped (bob_image.shape[1], bob_image.shape[2], 3)
        """

        rows, cols = bob_image.shape[1], bob_image.shape[2]
        skimage = numpy.zeros(shape=(rows, cols, 3))

        # Copy the planes one by one: 0 = red, 1 = green, 2 = blue
        for channel in range(3):
            skimage[:, :, channel] = bob_image[channel, :, :]

        return skimage

89
    def get_batch(self):
90 91

        # Shuffling samples
92
        indexes = numpy.array(range(self.data.shape[0]))
93 94
        numpy.random.shuffle(indexes)

95 96
        selected_data = numpy.zeros(shape=self.shape)
        for i in range(self.batch_size):
97

98 99
            file_name = self.data[indexes[i]]
            data = self.load_from_file(file_name, self.shape)
100 101

            selected_data[i, ...] = data
102 103
            if self.scale:
                selected_data[i, ...] *= self.scale_value
104

105
        selected_labels = self.labels[indexes[0:self.batch_size]]
106 107

        return selected_data.astype("float32"), selected_labels
108

109 110 111 112 113
    def rescale(self, data):
        """
        Reescale a single sample with input_shape

        """
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
114 115 116
        #if self.input_shape != data.shape:
        if self.bob_shape != data.shape:

117 118 119 120 121 122 123 124
            # TODO: Implement a better way to do this reescaling
            # If it is gray scale
            if self.input_shape[2] == 1:
                copy = data[:, :, 0].copy()
                dst = numpy.zeros(shape=self.input_shape[0:2])
                bob.ip.base.scale(copy, dst)
                dst = numpy.reshape(dst, self.input_shape)
            else:
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
125 126 127 128 129 130 131 132 133 134 135 136
                #dst = numpy.resize(data, self.bob_shape) # Scaling with numpy, because bob is c,w,d instead of w,h,c
                dst = numpy.zeros(shape=self.bob_shape)

                # TODO: LAME SOLUTION
                if data.shape[0] != 3: # GRAY SCALE IMAGES IN A RGB DATABASE
                    step_data = numpy.zeros(shape=(3, data.shape[0], data.shape[1]))
                    step_data[0, ...] = data[:, :]
                    step_data[1, ...] = data[:, :]
                    step_data[2, ...] = data[:, :]
                    data = step_data

                bob.ip.base.scale(data, dst)
137 138 139 140 141

            return dst
        else:
            return data

142
    def get_pair(self, zero_one_labels=True):
143 144 145 146 147 148 149 150 151
        """
        Get a random pair of samples

        **Parameters**
            is_target_set_train: Defining the target set to get the batch

        **Return**
        """

152 153 154
        data = numpy.zeros(shape=self.shape, dtype='float32')
        data_p = numpy.zeros(shape=self.shape, dtype='float32')
        labels_siamese = numpy.zeros(shape=self.shape[0], dtype='float32')
155 156

        genuine = True
157 158 159 160
        for i in range(self.shape[0]):
            file_name, file_name_p = self.get_genuine_or_not(self.data, self.labels, genuine=genuine)
            data[i, ...] = self.load_from_file(str(file_name), self.shape)
            data_p[i, ...] = self.load_from_file(str(file_name_p), self.shape)
161 162 163 164 165 166 167 168 169 170 171 172 173

            if zero_one_labels:
                labels_siamese[i] = not genuine
            else:
                labels_siamese[i] = -1 if genuine else +1
            genuine = not genuine

        if self.scale:
            data *= self.scale_value
            data_p *= self.scale_value

        return data, data_p, labels_siamese

174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199
    def get_random_triplet(self):
        """
        Get a random pair of samples

        **Parameters**
            is_target_set_train: Defining the target set to get the batch

        **Return**
        """

        data_a = numpy.zeros(shape=self.shape, dtype='float32')
        data_p = numpy.zeros(shape=self.shape, dtype='float32')
        data_n = numpy.zeros(shape=self.shape, dtype='float32')

        for i in range(self.shape[0]):
            file_name_a, file_name_p, file_name_n = self.get_one_triplet(self.data, self.labels)
            data_a[i, ...] = self.load_from_file(str(file_name_a), self.shape)
            data_p[i, ...] = self.load_from_file(str(file_name_p), self.shape)
            data_n[i, ...] = self.load_from_file(str(file_name_n), self.shape)

        if self.scale:
            data_a *= self.scale_value
            data_p *= self.scale_value
            data_n *= self.scale_value

        return data_a, data_p, data_n