SiameseMemory.py 1.95 KB
Newer Older
1 2 3 4 5 6 7
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 

import numpy

8 9 10
from .Memory import Memory
from .Siamese import Siamese
import tensorflow as tf
11 12


13
class SiameseMemory(Siamese, Memory):
    """
     Siamese data shuffler for datasets held entirely in memory.

     Each batch is made of sample pairs that alternate between genuine pairs
     and non-genuine (impostor) pairs, together with one label per pair.
    """

    def __init__(self, data, labels,
                 input_shape,
                 input_dtype="float64",
                 scale=True,
                 batch_size=1,
                 seed=10):
        """
         Shuffler that deals with in-memory datasets.

         **Parameters**
           data: Array with the samples. After the base-class initialization,
             ``self.data`` is cast to ``input_dtype`` and, when ``scale`` is
             set, multiplied by ``self.scale_value``.
           labels: Labels of the samples in ``data``.
           input_shape: Shape of the input samples (forwarded to the base
             class).
           input_dtype: dtype that ``self.data`` is converted to.
           scale: If True, multiply ``self.data`` by ``self.scale_value``
             (the value is provided by a base class; presumably a
             normalization factor, TODO confirm).
           batch_size: Forwarded to the base class; presumably determines
             ``self.shape[0]``, the number of pairs per batch (TODO confirm).
           seed: Seed for numpy's *global* random number generator.
        """

        super(SiameseMemory, self).__init__(
            data=data,
            labels=labels,
            input_shape=input_shape,
            input_dtype=input_dtype,
            scale=scale,
            batch_size=batch_size,
            seed=seed
        )
        # Setting the seed. Note: this reseeds numpy's process-wide RNG, not a
        # generator private to this object.
        numpy.random.seed(seed)

        # astype returns a copy, so the in-place scaling below does not touch
        # the caller's array.
        self.data = self.data.astype(input_dtype)
        if self.scale:
            self.data *= self.scale_value

    def get_batch(self, zero_one_labels=True):
        """
        Get a batch of sample pairs, alternating genuine and non-genuine pairs.

        The first pair of the batch is genuine, the second is not, and so on.

        **Parameters**
            zero_one_labels: If True, the label is 0 for a genuine pair and 1
              for a non-genuine pair. Otherwise it is -1 for a genuine pair
              and +1 for a non-genuine pair.

        **Return**
            A list ``[data, data_p, labels_siamese]``. ``data`` and ``data_p``
            are float32 arrays of shape ``self.shape`` holding the first and
            second sample of each pair. ``labels_siamese`` is a float32 vector
            of length ``self.shape[0]`` with one label per pair.
        """
        data = numpy.zeros(shape=self.shape, dtype='float32')
        data_p = numpy.zeros(shape=self.shape, dtype='float32')
        labels_siamese = numpy.zeros(shape=self.shape[0], dtype='float32')

        genuine = True
        for i in range(self.shape[0]):
            data[i, ...], data_p[i, ...] = self.get_genuine_or_not(self.data, self.labels, genuine=genuine)
            if zero_one_labels:
                labels_siamese[i] = 0 if genuine else 1
            else:
                labels_siamese[i] = -1 if genuine else +1
            # Alternate so the batch holds an equal mix of both pair types.
            genuine = not genuine

        return [data, data_p, labels_siamese]