BaseDataShuffler.py 3.04 KB
Newer Older
1 2 3 4 5 6 7 8 9 10
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 

import numpy
import tensorflow as tf


class BaseDataShuffler(object):
    """
     Base class holding the data/label bookkeeping shared by the data shufflers.

     It stores the samples, their labels and the batch geometry, keeps a randomly
     permuted array of sample indexes, and provides helpers to build TensorFlow
     placeholders and to draw genuine (same label) or impostor (different label)
     pairs of samples.
    """

    def __init__(self, data, labels,
                 input_shape,
                 input_dtype="float64",
                 scale=True,
                 batch_size=1):
        """
         The class provides base functionalities to shuffle the data

         **Parameters**
           data: The input samples (stored as ``self.data``)
           labels: One label per sample; ``len(labels)`` is the number of samples
           input_shape: Shape of a single sample, without the batch dimension
                        (list or tuple)
           input_dtype: Data type of the input (stored as ``self.input_dtype``;
                        not used by this class)
           scale: Scaling flag stored as ``self.scale`` (not used by this class;
                  presumably tells subclasses to multiply the data by ``scale_value``)
           batch_size: Number of samples per batch (first entry of ``self.shape``)
        """

        self.scale = scale
        # Exactly 1/256; presumably used by subclasses to bring 8-bit pixel values
        # into [0, 1) -- TODO confirm
        self.scale_value = 0.00390625
        self.input_dtype = input_dtype

        # TODO: Check if the batch size is higher than the input data
        self.batch_size = batch_size

        self.data = data
        # Full batch shape (batch_size, *input_shape); list() lets callers pass
        # the per-sample shape either as a list or as a tuple
        self.shape = tuple([batch_size] + list(input_shape))

        self.labels = labels
        # Distinct labels (clients); their order follows set iteration order
        self.possible_labels = list(set(self.labels))

        # Number of available samples
        self.n_samples = len(self.labels)

        # Shuffling all the indexes
        self.indexes = numpy.arange(self.n_samples)
        numpy.random.shuffle(self.indexes)

    def get_placeholders_forprefetch(self, name=""):
        """
        Returns placeholders for the data and the labels with an unspecified
        (``None``) batch dimension

        **Parameters**
          name: Name given to the data placeholder

        **Returns**
          ``(data, labels)``: a float32 placeholder of shape ``(None, *input_shape)``
          and an int64 placeholder of shape ``(None,)``
        """
        data = tf.placeholder(tf.float32, shape=tuple([None] + list(self.shape[1:])), name=name)
        labels = tf.placeholder(tf.int64, shape=[None, ])

        return data, labels

    def get_placeholders(self, name=""):
        """
        Returns placeholders for the data and the labels with the size of your batch

        **Parameters**
          name: Name given to the data placeholder

        **Returns**
          ``(data, labels)``: a float32 placeholder of shape ``self.shape`` and an
          int64 placeholder of shape ``(batch_size,)``
        """
        data = tf.placeholder(tf.float32, shape=self.shape, name=name)
        # One label per sample of the batch (1-D vector of length batch_size)
        labels = tf.placeholder(tf.int64, shape=[self.shape[0]])

        return data, labels

    def get_genuine_or_not(self, input_data, input_labels, genuine=True):
        """
        Randomly picks a pair of samples

        **Parameters**
          input_data: Samples, indexed along the first axis
          input_labels: Labels aligned with ``input_data`` (any sequence; it is
                        converted with ``numpy.asarray``)
          genuine: If ``True`` both samples share the same label; otherwise they
                   come from two different labels of ``self.possible_labels``

        **Returns**
          ``(data, data_p)``: the two selected samples

        **Raises**
          IndexError: if a selected label has too few samples in ``input_labels``
                      (fewer than two for a genuine pair, none for an impostor pair)
          ValueError: (from ``numpy.random.choice``) if an impostor pair is requested
                      with fewer than two possible labels
        """
        # Element-wise comparison below needs an array; a plain list would compare
        # as a whole and silently match nothing
        input_labels = numpy.asarray(input_labels)

        if genuine:
            # Getting a client
            index = numpy.random.randint(len(self.possible_labels))
            label = self.possible_labels[index]

            # Getting the indexes of the data from a particular client
            indexes = numpy.where(input_labels == label)[0]
            numpy.random.shuffle(indexes)

            # Picking a pair (this label needs at least two samples)
            data = input_data[indexes[0], ...]
            data_p = input_data[indexes[1], ...]

        else:
            # Picking a pair of different clients. The drawn positions are mapped to
            # labels in plain variables: writing labels back into the integer array
            # returned by numpy.random.choice would truncate float labels and reject
            # string labels.
            index = numpy.random.choice(len(self.possible_labels), 2, replace=False)
            label = self.possible_labels[index[0]]
            label_p = self.possible_labels[index[1]]

            # Getting the indexes of the two clients
            indexes = numpy.where(input_labels == label)[0]
            indexes_p = numpy.where(input_labels == label_p)[0]
            numpy.random.shuffle(indexes)
            numpy.random.shuffle(indexes_p)

            # Picking a pair
            data = input_data[indexes[0], ...]
            data_p = input_data[indexes_p[0], ...]

        return data, data_p