BaseDataShuffler.py 3.88 KB
Newer Older
1 2 3 4 5 6 7 8 9 10
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 

import numpy
import tensorflow as tf


class BaseDataShuffler(object):
    """
    Base functionalities to shuffle a labelled data set and to draw random
    pairs and triplets of samples from it.
    """

    def __init__(self, data, labels,
                 input_shape,
                 input_dtype="float64",
                 scale=True,
                 batch_size=1):
        """
         The class provides base functionalities to shuffle the data

         **Parameters**
           data: The data set; stored as-is in ``self.data``.
           labels: One label per sample. Must support ``len()`` and iteration
                   over hashable elements.
           input_shape: Shape of ONE sample (list or tuple), without the
                        batch dimension.
           input_dtype: Stored as-is in ``self.input_dtype``; not used by this
                        base class.
           scale: Stored as-is in ``self.scale``; not used by this base class
                  (presumably enables a multiplication by ``scale_value`` in
                  subclasses -- to be confirmed there).
           batch_size: Number of samples per batch; becomes the first
                       dimension of ``self.shape``.
        """

        self.scale = scale
        self.scale_value = 0.00390625  # == 1. / 256
        self.input_dtype = input_dtype

        # TODO: Check if the batch size is higher than the input data
        self.batch_size = batch_size

        self.data = data

        # Normalise once so that both lists and tuples are accepted
        # (``[batch_size] + input_shape`` fails with a tuple).
        input_shape = tuple(input_shape)
        self.shape = (batch_size,) + input_shape
        self.input_shape = input_shape

        self.labels = labels
        self.possible_labels = list(set(self.labels))

        # Total number of samples available
        self.n_samples = len(self.labels)

        # Shuffling all the indexes
        self.indexes = numpy.arange(self.n_samples)
        numpy.random.shuffle(self.indexes)

    def get_placeholders_forprefetch(self, name=""):
        """
        Returns a (data, labels) pair of placeholders whose batch dimension is
        left unspecified (``None``); the remaining data dimensions come from
        ``self.shape[1:]``.

        ``name`` is given to the data placeholder only. The data placeholder
        is always ``tf.float32`` and the labels placeholder ``tf.int64``.
        """
        data = tf.placeholder(tf.float32, shape=tuple([None] + list(self.shape[1:])), name=name)
        labels = tf.placeholder(tf.int64, shape=[None, ])

        return data, labels

    def get_placeholders(self, name=""):
        """
        Returns a (data, labels) pair of placeholders with the size of your
        batch: data has shape ``self.shape`` and labels ``self.shape[0]``.

        ``name`` is given to the data placeholder only. The data placeholder
        is always ``tf.float32`` and the labels placeholder ``tf.int64``.
        """
        data = tf.placeholder(tf.float32, shape=self.shape, name=name)
        labels = tf.placeholder(tf.int64, shape=self.shape[0])

        return data, labels

    def _draw_two_labels(self):
        """
        Returns two distinct labels drawn at random from
        ``self.possible_labels``.
        """
        # ``choice`` yields POSITIONS in ``possible_labels``; they have to be
        # mapped to the labels themselves before being compared with data.
        first, second = numpy.random.choice(len(self.possible_labels), 2, replace=False)
        return self.possible_labels[int(first)], self.possible_labels[int(second)]

    def _shuffled_indexes(self, input_labels, label):
        """
        Returns, in random order, the indexes of ``input_labels`` that are
        equal to ``label``.
        """
        indexes = numpy.where(numpy.asarray(input_labels) == label)[0]
        numpy.random.shuffle(indexes)
        return indexes

    def get_genuine_or_not(self, input_data, input_labels, genuine=True):
        """
        Returns a random pair of samples ``(data, data_p)``.

        **Parameters**
          input_data: Samples, indexed along the first axis.
          input_labels: Label of each sample in ``input_data``.
          genuine: If True, both samples are two different samples of the same
                   client; otherwise they come from two different clients.

        Raises ``IndexError`` if a picked client does not have enough samples
        in ``input_labels`` (two for a genuine pair, one otherwise).
        """

        if genuine:
            # Getting a client
            position = numpy.random.randint(len(self.possible_labels))
            label = self.possible_labels[position]

            # Getting the indexes of the data from a particular client
            indexes = self._shuffled_indexes(input_labels, label)

            # Picking a pair
            data = input_data[indexes[0], ...]
            data_p = input_data[indexes[1], ...]

        else:
            # Picking a pair of labels from different clients
            label, label_p = self._draw_two_labels()

            # Getting the indexes of the two clients
            indexes = self._shuffled_indexes(input_labels, label)
            indexes_p = self._shuffled_indexes(input_labels, label_p)

            # Picking a pair
            data = input_data[indexes[0], ...]
            data_p = input_data[indexes_p[0], ...]

        return data, data_p

    def get_one_triplet(self, input_data, input_labels):
        """
        Returns a random triplet ``(anchor, positive, negative)``: anchor and
        positive are two different samples of one client, negative is a
        sample of another client.

        **Parameters**
          input_data: Samples, indexed along the first axis.
          input_labels: Label of each sample in ``input_data``.

        Raises ``IndexError`` if a picked client does not have enough samples
        in ``input_labels`` (two for the positive client, one for the
        negative one).
        """
        # Getting a pair of clients
        label_positive, label_negative = self._draw_two_labels()

        # Getting the indexes of the data from a particular client
        indexes = self._shuffled_indexes(input_labels, label_positive)

        # Picking a positive pair
        data_anchor = input_data[indexes[0], ...]
        data_positive = input_data[indexes[1], ...]

        # Picking a negative sample
        indexes = self._shuffled_indexes(input_labels, label_negative)
        data_negative = input_data[indexes[0], ...]

        return data_anchor, data_positive, data_negative