SiameseMemory.py 2.5 KB
Newer Older
1 2 3 4 5 6 7
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 

import numpy

8 9 10
from .Memory import Memory
from .Siamese import Siamese
import tensorflow as tf
11 12


13
class SiameseMemory(Siamese, Memory):
    """
     Siamese shuffler for datasets that are held in memory.

     Each call to :py:meth:`get_batch` builds pairs of samples from
     ``self.data``/``self.labels`` via ``get_genuine_or_not``, alternating
     between ``genuine=True`` and ``genuine=False`` pairs.
    """

    def __init__(self, data, labels,
                 input_shape,
                 input_dtype="float",
                 scale=True,
                 batch_size=1,
                 seed=10,
                 data_augmentation=None):
        """
         Shuffler that deals with in-memory datasets

         **Parameters**
           data: Input samples. After the parent constructors run,
             ``self.data`` is cast to ``input_dtype``.
           labels: Labels of ``data``; passed to ``get_genuine_or_not`` when
             building pairs.
           input_shape: Forwarded to the parent constructors (presumably the
             shape of a single sample -- TODO confirm against ``Memory``).
           input_dtype: dtype that ``self.data`` is cast to.
           scale: If True, ``get_batch`` multiplies both sides of each pair by
             ``self.scale_value``.
           batch_size: Forwarded to the parent constructors (presumably the
             number of pairs per batch; ``get_batch`` uses ``self.shape[0]``).
           seed: Seed passed to ``numpy.random.seed`` (global NumPy RNG state).
           data_augmentation: Optional callable applied by ``get_batch`` to
             every sample of both sides of the batch; ``None`` disables it.
        """

        super(SiameseMemory, self).__init__(
            data=data,
            labels=labels,
            input_shape=input_shape,
            input_dtype=input_dtype,
            scale=scale,
            batch_size=batch_size,
            seed=seed,
            data_augmentation=data_augmentation
        )

        # Setting the seed. NOTE: this re-seeds the *global* NumPy RNG, which
        # also affects any other code relying on numpy.random.
        numpy.random.seed(seed)

        self.data = self.data.astype(input_dtype)

    def get_batch(self, zero_one_labels=True):
        """
        Get a batch of sample pairs.

        Pairs alternate between genuine and non-genuine ones, starting with a
        genuine pair (``genuine=True`` passed to ``get_genuine_or_not``).

        **Parameters**
            zero_one_labels: If True, genuine pairs are labelled 0 and
              non-genuine pairs 1. Otherwise genuine pairs are labelled -1 and
              non-genuine pairs +1.

        **Return**
            A list ``[data, data_p, labels_siamese]``. ``data`` and ``data_p``
            are float32 arrays of shape ``self.shape`` holding the first and
            second element of each pair. ``labels_siamese`` is a float64 vector
            of length ``self.shape[0]`` with one label per pair.
        """
        data = numpy.zeros(shape=self.shape, dtype='float')
        data_p = numpy.zeros(shape=self.shape, dtype='float')
        labels_siamese = numpy.zeros(shape=self.shape[0], dtype='float')

        genuine = True
        for i in range(self.shape[0]):
            data[i, ...], data_p[i, ...] = self.get_genuine_or_not(self.data, self.labels, genuine=genuine)
            if zero_one_labels:
                # bool -> 0.0 (genuine) / 1.0 (non-genuine)
                labels_siamese[i] = not genuine
            else:
                labels_siamese[i] = -1 if genuine else +1
            # Alternate so the batch is balanced between both pair types
            genuine = not genuine

        # Applying the data augmentation. Both sides of pair i are processed
        # back to back; the order is kept so a stochastic augmentation
        # consumes its randomness in the same sequence.
        if self.data_augmentation is not None:
            for i in range(data.shape[0]):
                d = self.bob2skimage(self.data_augmentation(self.skimage2bob(data[i, ...])))
                data[i, ...] = d

                d = self.bob2skimage(self.data_augmentation(self.skimage2bob(data_p[i, ...])))
                data_p[i, ...] = d

        # Scaling (applied after augmentation)
        if self.scale:
            data *= self.scale_value
            data_p *= self.scale_value

        return [data.astype("float32"), data_p.astype("float32"), labels_siamese]