#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST

import numpy
7
8
9
import bob.core
# Package-level logger obtained from bob.core under the "bob.learn.tensorflow"
# name; SiameseDisk below does not reference it directly.
logger = bob.core.log.setup("bob.learn.tensorflow")

10
11
from .Disk import Disk
from .Siamese import Siamese
12
13


14
class SiameseDisk(Siamese, Disk):
    """
     Siamese data shuffler whose samples are files on disk.

     Combines the pair selection of :py:class:`Siamese` (``get_genuine_or_not``)
     with the file loading of :py:class:`Disk` (``load_from_file``): every call
     to :py:meth:`get_batch` returns two batches of loaded samples plus one
     label per pair.
    """

    def __init__(self, data, labels,
                 input_shape,
                 input_dtype="float64",
                 scale=True,
                 batch_size=1,
                 seed=10):
        """
         Shuffler that deals with a list of files and serves them in pairs.

         **Parameters**

         data: List (or array) of file names, one per sample. Lists are
               converted to ``numpy.array``.
         labels: List (or array) of labels aligned with ``data``. Lists are
                 converted to ``numpy.array``.
         input_shape: Shape of a single input. If ``input_shape != data.shape``,
                      the data will be reshaped. It must have at least three
                      entries, because element 2 is moved to the front to build
                      ``self.bob_shape`` (presumably (height, width, channels)
                      -- TODO confirm).
         input_dtype: Data type of the input (default ``"float64"``); forwarded
                      to the base class.
         scale: If true, :py:meth:`get_batch` multiplies both batches by
                ``self.scale_value`` (default ``True``).
         batch_size: Number of samples per batch (default ``1``); forwarded to
                     the base class.
         seed: Seed for the global ``numpy.random`` state (default ``10``).
        """
        # Accept plain Python lists as well as numpy arrays.
        if isinstance(data, list):
            data = numpy.array(data)

        if isinstance(labels, list):
            labels = numpy.array(labels)

        super(SiameseDisk, self).__init__(
            data=data,
            labels=labels,
            input_shape=input_shape,
            input_dtype=input_dtype,
            scale=scale,
            batch_size=batch_size,
            seed=seed
        )

        # Re-seed the *global* numpy RNG after base-class initialisation;
        # presumably so the random pair selection is reproducible.
        numpy.random.seed(seed)

        # TODO: very bad solution to deal with bob-shaped images and tf-shaped images.
        # Moves entry 2 of input_shape to the front: (a, b, c) -> (c, a, b).
        self.bob_shape = tuple([input_shape[2]] + list(input_shape[0:2]))

    def get_batch(self):
        """
        Get a batch of random sample pairs, alternating genuine and impostor pairs.

        The first pair is genuine, the second impostor, and so on. Both members
        of every pair are loaded from disk with ``self.load_from_file``.

        **Return**

         data: First member of each pair, array of shape ``self.shape``, dtype float32.
         data_p: Second member of each pair, array of shape ``self.shape``, dtype float32.
         labels_siamese: Array of ``self.shape[0]`` float32 labels: ``0`` for a
                         pair requested as genuine, ``1`` otherwise.
        """
        # NOTE(review): buffers are always float32, independent of the
        # ``input_dtype`` passed to __init__ -- confirm this is intended.
        data = numpy.zeros(shape=self.shape, dtype='float32')
        data_p = numpy.zeros(shape=self.shape, dtype='float32')
        labels_siamese = numpy.zeros(shape=self.shape[0], dtype='float32')

        genuine = True
        for i in range(self.shape[0]):
            file_name, file_name_p = self.get_genuine_or_not(self.data, self.labels, genuine=genuine)
            data[i, ...] = self.load_from_file(str(file_name))
            data_p[i, ...] = self.load_from_file(str(file_name_p))

            # 0 for a genuine pair, 1 for an impostor pair.
            labels_siamese[i] = 0.0 if genuine else 1.0

            # Alternate pair types so each batch is (roughly) balanced.
            genuine = not genuine

        if self.scale:
            data *= self.scale_value
            data_p *= self.scale_value

        return data, data_p, labels_siamese