#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 

import numpy
from .MemoryDataShuffler import MemoryDataShuffler


class MemoryPairDataShuffler(MemoryDataShuffler):
    """
     Serves siamese pairs and triplets drawn from in-memory data.

     Splitting and preprocessing are delegated to MemoryDataShuffler (every
     constructor argument is forwarded to it); this class only decides which
     samples of a split are combined into pairs or triplets.
    """

    def __init__(self, data, labels, input_shape, perc_train=0.9, scale=True, train_batch_size=1, validation_batch_size=300):
        """
         Build the shuffler by forwarding every argument to MemoryDataShuffler.

         **Parameters**
           data: Samples to shuffle; forwarded unchanged.
           labels: Per-sample labels; forwarded unchanged.
           input_shape: Forwarded unchanged.
           perc_train: Forwarded unchanged (presumably the fraction of samples kept for training).
           scale: Forwarded unchanged.
           train_batch_size: Forwarded multiplied by two (see note below).
           validation_batch_size: Forwarded unchanged.
        """
        # NOTE(review): the parent receives twice the requested training batch
        # size, presumably so that one training batch holds both members of
        # train_batch_size pairs -- confirm against MemoryDataShuffler.
        super(MemoryPairDataShuffler, self).__init__(data, labels,
                                                     input_shape=input_shape,
                                                     perc_train=perc_train,
                                                     scale=scale,
                                                     train_batch_size=train_batch_size*2,
                                                     validation_batch_size=validation_batch_size)

    @staticmethod
    def _indexes_per_label(labels, n_labels):
        """
        Group sample positions by label.

        Returns a list whose k-th entry holds, in ascending order, the positions of
        every sample whose label equals k, for k in range(n_labels). This matches
        ``numpy.where(labels == k)[0]`` for each k, but needs one sort instead of
        one full scan per k.
        """
        flat_labels = numpy.asarray(labels).reshape(-1)
        # A stable sort keeps positions ascending inside each label group.
        order = numpy.argsort(flat_labels, kind='mergesort')
        sorted_labels = flat_labels[order]
        keys = numpy.arange(n_labels)
        starts = numpy.searchsorted(sorted_labels, keys, side='left')
        stops = numpy.searchsorted(sorted_labels, keys, side='right')
        return [order[start:stop] for start, stop in zip(starts, stops)]

    @staticmethod
    def _labels_with_at_least(groups, min_samples):
        """Return the labels whose group in ``groups`` holds at least ``min_samples`` positions."""
        return [label for label, indexes in enumerate(groups) if len(indexes) >= min_samples]

    @staticmethod
    def _random_member(items):
        """Return one element of ``items`` chosen uniformly at random."""
        return items[numpy.random.randint(len(items))]

    def get_pair(self, train_dataset=True, zero_one_labels=True):
        """
        Get a batch of random pairs, alternating genuine and impostor pairs

        Even positions of the batch hold genuine pairs (two different samples of
        the same class); odd positions hold impostor pairs (one sample from each of
        two different classes). Classes are taken from ``range(self.total_labels)``,
        and only classes with enough samples in the chosen split are drawn. A split
        in which some classes are missing or have a single sample therefore does
        not fail at random.

        **Parameters**
            train_dataset: If True, sample from the training split; otherwise from the validation split.
            zero_one_labels: If True, genuine pairs are labelled 0 and impostor pairs 1;
                otherwise genuine pairs are labelled -1 and impostor pairs +1.

        **Return**
            data: float32 array of shape ``self.train_shape`` (or ``self.validation_shape``)
                holding the first sample of each pair.
            data_p: float32 array of the same shape holding the second sample of each pair.
            labels_siamese: float32 array of length ``shape[0]`` with one label per pair.

        **Raises**
            ValueError: if a genuine pair is needed but no class has two samples in the split,
                or an impostor pair is needed but fewer than two classes have any sample.
        """
        if train_dataset:
            target_data = self.train_data
            target_labels = self.train_labels
            shape = self.train_shape
        else:
            target_data = self.validation_data
            target_labels = self.validation_labels
            shape = self.validation_shape

        n_pairs = shape[0]

        # Group the split once per call instead of rescanning all labels for every pair.
        groups = self._indexes_per_label(target_labels, self.total_labels)
        genuine_classes = self._labels_with_at_least(groups, 2)
        impostor_classes = self._labels_with_at_least(groups, 1)
        if n_pairs > 0 and not genuine_classes:
            raise ValueError("No class has at least two samples; cannot build a genuine pair")
        if n_pairs > 1 and len(impostor_classes) < 2:
            raise ValueError("Fewer than two classes have samples; cannot build an impostor pair")

        data = numpy.zeros(shape=shape, dtype='float32')
        data_p = numpy.zeros(shape=shape, dtype='float32')
        labels_siamese = numpy.zeros(shape=shape[0], dtype='float32')

        for i in range(n_pairs):
            genuine = (i % 2 == 0)
            if genuine:
                label = self._random_member(genuine_classes)
                first, second = numpy.random.choice(groups[label], 2, replace=False)
            else:
                label, label_p = numpy.random.choice(impostor_classes, 2, replace=False)
                first = self._random_member(groups[label])
                second = self._random_member(groups[label_p])

            data[i, ...] = target_data[first, ...]
            data_p[i, ...] = target_data[second, ...]

            if zero_one_labels:
                labels_siamese[i] = 0 if genuine else 1
            else:
                labels_siamese[i] = -1 if genuine else +1

        return data, data_p, labels_siamese

    def get_triplet(self, n_labels, n_triplets=1, is_target_set_train=True):
        """
        Get a batch of random triplets (anchor, positive, negative)

        Anchor and positive are two different samples of one class; the negative
        is a sample of a different class. Only labels in ``range(n_labels)`` are
        sampled, and only classes with enough samples in the chosen split are drawn.

        **Parameters**
            n_labels: Number of classes; labels in ``range(n_labels)`` are considered.
            n_triplets: Number of triplets in the batch.
            is_target_set_train: If True, sample from the training split; otherwise from the validation split.

        **Return**
            data_a, data_p, data_n: float32 arrays of shape ``(n_triplets,) + sample_shape``,
                where ``sample_shape`` is the shape of one sample of the split.
            labels_a, labels_p, labels_n: float32 arrays of length ``n_triplets``;
                labels_a and labels_p are identical.

        **Raises**
            ValueError: if triplets are requested but no class has two samples,
                or fewer than two classes have any sample.
        """
        if is_target_set_train:
            target_data = self.train_data
            target_labels = self.train_labels
        else:
            target_data = self.validation_data
            target_labels = self.validation_labels

        groups = self._indexes_per_label(target_labels, n_labels)
        anchor_classes = self._labels_with_at_least(groups, 2)
        negative_classes = self._labels_with_at_least(groups, 1)
        if n_triplets > 0:
            if not anchor_classes:
                raise ValueError("No class has at least two samples; cannot build an anchor/positive pair")
            # Every anchor class is also a negative candidate, so two candidates
            # guarantee that a distinct negative class exists.
            if len(negative_classes) < 2:
                raise ValueError("Fewer than two classes have samples; cannot build a negative")

        # Same as the former (n_triplets, w, h, c) layout for rank-4 data, but any rank works.
        batch_shape = (n_triplets,) + tuple(target_data.shape[1:])
        data_a = numpy.zeros(shape=batch_shape, dtype='float32')
        data_p = numpy.zeros(shape=batch_shape, dtype='float32')
        data_n = numpy.zeros(shape=batch_shape, dtype='float32')
        labels_a = numpy.zeros(shape=n_triplets, dtype='float32')
        labels_p = numpy.zeros(shape=n_triplets, dtype='float32')
        labels_n = numpy.zeros(shape=n_triplets, dtype='float32')

        for i in range(n_triplets):
            label_positive = self._random_member(anchor_classes)
            anchor, positive = numpy.random.choice(groups[label_positive], 2, replace=False)

            # At least two candidates exist (checked above), so this terminates quickly.
            label_negative = label_positive
            while label_negative == label_positive:
                label_negative = self._random_member(negative_classes)
            negative = self._random_member(groups[label_negative])

            data_a[i, ...] = target_data[anchor, ...]
            data_p[i, ...] = target_data[positive, ...]
            data_n[i, ...] = target_data[negative, ...]
            labels_a[i] = label_positive
            labels_p[i] = label_positive
            labels_n[i] = label_negative

        return data_a, data_p, data_n, labels_a, labels_p, labels_n