#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST

import numpy

from .Memory import Memory
from .Siamese import Siamese
import tensorflow as tf


class SiameseMemory(Siamese, Memory):
    """
    Data shuffler that builds batches of sample pairs from an in-memory dataset

    The pair selection itself is delegated to ``get_genuine_or_not``, which is
    not defined in this class (presumably inherited from ``Siamese`` -- TODO
    confirm).
    """

    def __init__(self, data, labels,
                 input_shape,
                 input_dtype="float64",
                 scale=True,
                 batch_size=1):
        """
         Shuffler that deals with memory datasets

         **Parameters**
           data: Samples of the dataset; forwarded to the base-class constructor
           labels: Labels of the samples in ``data``; forwarded to the base-class constructor
           input_shape: Forwarded to the base-class constructor
           input_dtype: dtype that ``self.data`` is cast to once the base classes are initialised
           scale: Forwarded to the base-class constructor; when ``self.scale`` is true
                  afterwards, the data is multiplied by ``self.scale_value``
           batch_size: Forwarded to the base-class constructor
        """

        super(SiameseMemory, self).__init__(
            data=data,
            labels=labels,
            input_shape=input_shape,
            input_dtype=input_dtype,
            scale=scale,
            batch_size=batch_size
        )

        # ``self.data``, ``self.scale`` and ``self.scale_value`` are expected to
        # be set by the base-class constructors -- they are not assigned here.
        # ``astype`` returns a new array, so the in-place scaling below acts on
        # this object's own copy.
        self.data = self.data.astype(input_dtype)
        if self.scale:
            self.data *= self.scale_value

    def get_batch(self, zero_one_labels=True):
        """
        Get a batch of sample pairs, alternating genuine and non-genuine pairs

        The first pair of the batch is genuine, the second is not, and so on.

        **Parameters**
            zero_one_labels: If true, genuine pairs are labelled 0 and the other
                             pairs 1; otherwise genuine pairs are labelled -1 and
                             the other pairs +1

        **Return**
            A list ``[data, data_p, labels_siamese]``: two float32 arrays of shape
            ``self.shape`` holding the first and second element of every pair, and
            a float32 array of length ``self.shape[0]`` holding the pair labels
        """
        data = numpy.zeros(shape=self.shape, dtype='float32')
        data_p = numpy.zeros(shape=self.shape, dtype='float32')
        labels_siamese = numpy.zeros(shape=self.shape[0], dtype='float32')

        genuine = True
        for i in range(self.shape[0]):
            data[i, ...], data_p[i, ...] = self.get_genuine_or_not(self.data, self.labels, genuine=genuine)
            if zero_one_labels:
                # The boolean is stored as 0.0 (genuine) or 1.0 (not genuine).
                labels_siamese[i] = not genuine
            else:
                labels_siamese[i] = -1 if genuine else +1
            # Flip the pair type for the next slot of the batch.
            genuine = not genuine

        return [data, data_p, labels_siamese]