#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST

import numpy
import bob.io.base
import bob.io.image
import bob.ip.base
import bob.core
import tensorflow as tf

from .Disk import Disk
from .Triplet import Triplet

# Package-wide logger for bob.learn.tensorflow.
logger = bob.core.log.setup("bob.learn.tensorflow")


class TripletDisk(Triplet, Disk):
    """
     Triplet shuffler whose samples are files on disk.

     Batches are built by drawing three file names at a time with
     ``get_one_triplet`` and loading each one with ``load_from_file``.
     Both methods come from the parent classes (``Triplet`` / ``Disk``).
    """

    def __init__(self, data, labels,
                 input_shape,
                 input_dtype="float64",
                 scale=True,
                 batch_size=1,
                 seed=10):
        """
         Shuffler that deals with a list of files.

         **Parameters**

         data: List or array of samples. Each entry is converted with ``str()``
               and loaded as a file name in ``get_batch``. Lists are converted
               to ``numpy`` arrays.
         labels: List or array of labels. Lists are converted to ``numpy`` arrays.
         input_shape: Shape of the input. If `input_shape != data.shape`, the data will be reshaped.
                      Must have at least 3 entries; ``input_shape[2]`` is moved to the
                      front to build ``self.bob_shape`` (presumably (H, W, C) -> (C, H, W);
                      TODO confirm).
         input_dtype="float64": Forwarded to the parent shuffler.
         scale=True: Forwarded to the parent shuffler. ``get_batch`` multiplies the
                     loaded batch by ``self.scale_value`` when ``self.scale`` is truthy.
         batch_size=1: Forwarded to the parent shuffler.
         seed=10: Passed to ``numpy.random.seed``. This reseeds numpy's *global*
                  random state for the whole process.
        """

        if isinstance(data, list):
            data = numpy.array(data)

        if isinstance(labels, list):
            labels = numpy.array(labels)

        super(TripletDisk, self).__init__(
            data=data,
            labels=labels,
            input_shape=input_shape,
            input_dtype=input_dtype,
            scale=scale,
            batch_size=batch_size
        )
        # Setting the seed.
        # NOTE(review): this runs after the parent __init__, so any randomness
        # used during parent initialization is not covered by this seed.
        numpy.random.seed(seed)

        # TODO: very bad solution to deal with bob.shape images and tf shape images
        # Reorders input_shape so its third entry comes first.
        # NOTE(review): not used inside this class; get_batch passes self.shape instead.
        self.bob_shape = tuple([input_shape[2]] + list(input_shape[0:2]))

    def get_batch(self):
        """
        Get a random batch of triplets.

        One triplet of file names is drawn per batch row, and each file is
        loaded into the corresponding row of its array.

        **Return**
            A list ``[data_a, data_p, data_n]`` of three ``float32`` arrays,
            each of shape ``self.shape``. These are presumably the anchor,
            positive and negative samples, in the order returned by
            ``get_one_triplet``. When ``self.scale`` is truthy, all three are
            multiplied by ``self.scale_value``.
        """

        # NOTE(review): dtype is hard-coded to float32 regardless of input_dtype.
        data_a = numpy.zeros(shape=self.shape, dtype='float32')
        data_p = numpy.zeros(shape=self.shape, dtype='float32')
        data_n = numpy.zeros(shape=self.shape, dtype='float32')

        # self.shape[0] is the number of rows (triplets) in the batch.
        for i in range(self.shape[0]):
            file_name_a, file_name_p, file_name_n = self.get_one_triplet(self.data, self.labels)
            # NOTE(review): load_from_file receives the full batch shape (self.shape),
            # not the per-sample shape or self.bob_shape; confirm this is intended.
            data_a[i, ...] = self.load_from_file(str(file_name_a), self.shape)
            data_p[i, ...] = self.load_from_file(str(file_name_p), self.shape)
            data_n[i, ...] = self.load_from_file(str(file_name_n), self.shape)

        if self.scale:
            data_a *= self.scale_value
            data_p *= self.scale_value
            data_n *= self.scale_value

        return [data_a, data_p, data_n]