SiameseDisk.py 2.42 KB
Newer Older
1 2 3 4 5 6
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 

import numpy
7 8 9
import bob.core
logger = bob.core.log.setup("bob.learn.tensorflow")

10 11
from .Disk import Disk
from .Siamese import Siamese
12 13


14
class SiameseDisk(Siamese, Disk):
    """
     Siamese data shuffler whose samples are read from files on disk.

     Pair selection comes from :py:class:`Siamese` (``get_genuine_or_not``).
     File loading comes from :py:class:`Disk` (``load_from_file``).
    """

    def __init__(self, data, labels,
                 input_shape,
                 input_dtype="float64",
                 scale=True,
                 batch_size=1,
                 seed=10,
                 data_augmentation=None):
        """
         Shuffler that deals with a list of files and serves pairs of samples.

         **Parameters**

         data: List (or numpy array) of file names. Each entry is passed
           through ``str()`` and loaded with ``load_from_file`` when a batch
           is built. A plain list is converted to a numpy array.
         labels: List (or numpy array) of labels matching ``data``. A plain
           list is converted to a numpy array.
         input_shape: Shape of one input sample. If ``input_shape != data.shape``,
           the data will be reshaped. Expected to have at least 3 entries,
           presumably (height, width, channels); see ``bob_shape`` below.
         input_dtype="float64": Forwarded to the base class.
         scale=True: Forwarded to the base class.
         batch_size=1: Forwarded to the base class.
         seed=10: Forwarded to the base class, and also used to reseed
           numpy's *global* random generator.
         data_augmentation=None: Forwarded to the base class.
        """

        if isinstance(data, list):
            data = numpy.array(data)

        if isinstance(labels, list):
            labels = numpy.array(labels)

        super(SiameseDisk, self).__init__(
            data=data,
            labels=labels,
            input_shape=input_shape,
            input_dtype=input_dtype,
            scale=scale,
            batch_size=batch_size,
            seed=seed,
            data_augmentation=data_augmentation
        )
        # Setting the seed. NOTE: this reseeds numpy's process-wide global RNG,
        # not a generator private to this instance.
        numpy.random.seed(seed)

        # TODO: very bad solution to deal with bob.shape images and tf shape images.
        # Moves input_shape[2] to the front, e.g. (h, w, c) -> (c, h, w).
        self.bob_shape = tuple([input_shape[2]] + list(input_shape[0:2]))

    def get_batch(self):
        """
        Get a batch of sample pairs.

        For every row ``i`` of the batch, a pair of file names is drawn with
        ``get_genuine_or_not``. Draws alternate between ``genuine=True`` and
        ``genuine=False``, starting with ``genuine=True``. Both files are
        loaded with ``load_from_file``.

        **Return**
          sample_l: left samples. Allocated as a float32 array of shape
            ``self.shape``, then passed through ``normalize_sample``.
          sample_r: right samples. Allocated as a float32 array of shape
            ``self.shape``, then passed through ``normalize_sample``.
          labels_siamese: float32 array of length ``self.shape[0]``.
            0.0 where the pair was drawn with ``genuine=True``, otherwise 1.0.
        """

        sample_l = numpy.zeros(shape=self.shape, dtype='float32')
        sample_r = numpy.zeros(shape=self.shape, dtype='float32')
        labels_siamese = numpy.zeros(shape=self.shape[0], dtype='float32')

        genuine = True
        for i in range(self.shape[0]):
            file_name, file_name_p = self.get_genuine_or_not(self.data, self.labels, genuine=genuine)

            sample_l[i, ...] = self.load_from_file(str(file_name))
            sample_r[i, ...] = self.load_from_file(str(file_name_p))

            # The bool is stored as 0.0 for a genuine pair and 1.0 for an impostor pair.
            labels_siamese[i] = not genuine
            # Alternate pair types so the batch is balanced.
            genuine = not genuine

        sample_l = self.normalize_sample(sample_l)
        sample_r = self.normalize_sample(sample_r)

        return sample_l, sample_r, labels_siamese