TripletDisk.py 2.45 KB
Newer Older
1 2 3 4 5 6 7 8
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 

import numpy
import bob.io.base
import bob.io.image
9
import bob.ip.base
10 11 12
import bob.core
logger = bob.core.log.setup("bob.learn.tensorflow")

13 14
import tensorflow as tf

15 16
from .Disk import Disk
from .Triplet import Triplet
17 18


19
class TripletDisk(Triplet, Disk):
    """
     Triplet shuffler whose samples are referenced by file name.

     ``data`` holds file names. The corresponding samples are only read from
     disk (via ``self.load_from_file``) when a batch is requested in
     :py:meth:`get_batch`.
    """

    def __init__(self, data, labels,
                 input_shape,
                 input_dtype="float64",
                 scale=True,
                 batch_size=1,
                 seed=10,
                 data_augmentation=None):
        """
         Shuffler that deals with a list of files and serves triplets.

         **Parameters**

         data: File names (list or numpy array). A list is converted to a numpy array.
         labels: Labels associated with ``data`` (list or numpy array). A list is converted to a numpy array.
         input_shape: Shape of the input. If `input_shape != data.shape`, the data will be reshaped.
             It must have at least 3 entries, because it is also used to build ``self.bob_shape``.
         input_dtype="float64": Forwarded to the base class.
         scale=True: If true, the arrays returned by :py:meth:`get_batch` are multiplied by ``self.scale_value``.
         batch_size=1: Forwarded to the base class.
         seed=10: Seed passed to ``numpy.random.seed``. Note that this reseeds numpy's *global*
             random generator, which affects every other user of ``numpy.random`` in the process.
         data_augmentation=None: Forwarded to the base class.
        """

        if isinstance(data, list):
            data = numpy.array(data)

        if isinstance(labels, list):
            labels = numpy.array(labels)

        super(TripletDisk, self).__init__(
            data=data,
            labels=labels,
            input_shape=input_shape,
            input_dtype=input_dtype,
            scale=scale,
            batch_size=batch_size,
            data_augmentation=data_augmentation
        )
        # Setting the seed.
        # NOTE(review): this seeds the global numpy RNG rather than a private
        # generator; presumably get_one_triplet draws from numpy.random, so the
        # global seed is kept for reproducibility — confirm before changing.
        numpy.random.seed(seed)

        # TODO: very bad solution to deal with bob shape images and tf shape images.
        # Moves the last axis of input_shape to the front: (d0, d1, d2) -> (d2, d0, d1),
        # presumably (height, width, channels) -> (channels, height, width).
        self.bob_shape = tuple([input_shape[2]] + list(input_shape[0:2]))

    def get_batch(self):
        """
        Get a random batch of triplets.

        Each of the ``self.shape[0]`` rows is filled with one triplet of file
        names drawn by ``self.get_one_triplet`` and loaded with
        ``self.load_from_file``.

        **Return**
            A list ``[data_a, data_p, data_n]`` of three ``float32`` arrays of
            shape ``self.shape``, one per triplet member (``a``, ``p``, ``n`` —
            presumably anchor, positive, negative). When ``self.scale`` is
            true, every array is multiplied by ``self.scale_value``.
        """

        # One buffer per triplet member, kept in the order a, p, n.
        batches = [numpy.zeros(shape=self.shape, dtype='float32') for _ in range(3)]

        for i in range(self.shape[0]):
            # Explicit unpacking enforces that exactly three file names are returned.
            file_name_a, file_name_p, file_name_n = self.get_one_triplet(self.data, self.labels)
            for batch, file_name in zip(batches, (file_name_a, file_name_p, file_name_n)):
                batch[i, ...] = self.load_from_file(str(file_name))

        if self.scale:
            for batch in batches:
                batch *= self.scale_value  # in place, keeps float32

        return batches