#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 

import numpy

from .Memory import Memory
from .Siamese import Siamese
import tensorflow as tf


class SiameseMemory(Siamese, Memory):
    """
    Siamese shuffler over an in-memory dataset.

    Builds batches of sample pairs from ``data``/``labels``. Pairs alternate
    genuine / impostor. Each pair can optionally go through data augmentation,
    and every batch is normalized before it is returned.
    """

    def __init__(self, data, labels,
                 input_shape,
                 input_dtype="float",
                 scale=True,
                 batch_size=1,
                 seed=10,
                 data_augmentation=None):
        """
         Shuffler that deals with memory datasets

         **Parameters**
           data: Samples held in memory; cast to ``input_dtype`` after the base
             classes are initialized
           labels: Label of each sample in ``data``
           input_shape: Shape of one input sample, forwarded to the base shuffler
           input_dtype: dtype that ``self.data`` is converted to
           scale: Forwarded to the base shuffler
           batch_size: Forwarded to the base shuffler (presumably determines
             ``self.shape[0]``, the number of pairs per batch -- TODO confirm)
           seed: Seed passed to ``numpy.random.seed`` (this seeds numpy's global RNG)
           data_augmentation: Optional callable applied to every sample of a batch
             returned by ``get_batch``; ``None`` disables augmentation
        """

        super(SiameseMemory, self).__init__(
            data=data,
            labels=labels,
            input_shape=input_shape,
            input_dtype=input_dtype,
            scale=scale,
            batch_size=batch_size,
            seed=seed,
            data_augmentation=data_augmentation
        )
        # Setting the seed (note: global numpy RNG state, not a private generator)
        numpy.random.seed(seed)

        self.data = self.data.astype(input_dtype)

    def _augment(self, sample):
        """Apply ``self.data_augmentation`` to one sample, wrapped in skimage2bob / bob2skimage conversions."""
        return self.bob2skimage(self.data_augmentation(self.skimage2bob(sample)))

    def get_batch(self, zero_one_labels=True):
        """
        Get a batch of sample pairs, alternating genuine and impostor pairs

        **Parameters**
            zero_one_labels: If True, genuine pairs are labelled 0 and impostor
                pairs 1; otherwise genuine pairs are labelled -1 and impostor
                pairs +1

        **Return**
            A list ``[sample_l, sample_r, labels]``. ``sample_l`` and
            ``sample_r`` are float32 arrays of shape ``self.shape`` holding
            the left and right members of each pair, after optional
            augmentation and ``normalize_sample``. ``labels`` is a float array
            with one label per pair.
        """
        sample_l = numpy.zeros(shape=self.shape, dtype='float')
        sample_r = numpy.zeros(shape=self.shape, dtype='float')
        labels_siamese = numpy.zeros(shape=self.shape[0], dtype='float')

        # Pairs alternate genuine / impostor, starting with a genuine pair
        genuine = True
        for i in range(self.shape[0]):
            sample_l[i, ...], sample_r[i, ...] = self.get_genuine_or_not(self.data, self.labels, genuine=genuine)
            if zero_one_labels:
                labels_siamese[i] = not genuine
            else:
                labels_siamese[i] = -1 if genuine else +1
            genuine = not genuine

        # Applying the data augmentation; left then right per index, so the
        # augmentation callable sees samples in the same order as before
        if self.data_augmentation is not None:
            for i in range(sample_l.shape[0]):
                sample_l[i, ...] = self._augment(sample_l[i, ...])
                sample_r[i, ...] = self._augment(sample_r[i, ...])

        sample_l = self.normalize_sample(sample_l)
        sample_r = self.normalize_sample(sample_r)

        return [sample_l.astype("float32"), sample_r.astype("float32"), labels_siamese]