bio.py 4.18 KB
Newer Older
Amir MOHAMMADI's avatar
Amir MOHAMMADI committed
1
2
import six
import tensorflow as tf
Amir MOHAMMADI's avatar
Amir MOHAMMADI committed
3
from bob.bio.base import read_original_data
4
import logging
Amir MOHAMMADI's avatar
Amir MOHAMMADI committed
5

6
logger = logging.getLogger(__name__)
Amir MOHAMMADI's avatar
Amir MOHAMMADI committed
7

Amir MOHAMMADI's avatar
Amir MOHAMMADI committed
8

9
10
11
12
13
class BioGenerator(object):
    """A generator class which wraps bob.bio.base databases so that they can
    be used with tf.data.Dataset.from_generator

    Attributes
    ----------
    biofile_to_label : :obj:`object`, optional
        A callable with the signature of ``label = biofile_to_label(biofile)``.
        By default -1 is returned as label.
    biofiles : [:any:`bob.bio.base.database.BioFile`]
        The list of the bio files.
    database : :any:`bob.bio.base.database.BioDatabase`
        The database that you want to use.
    epoch : int
        The number of epochs that have been passed so far.
    keys : [str]
        The keys of samples obtained by calling ``biofile.make_path("", "")``
    labels : [int]
        The labels obtained by calling ``label = biofile_to_label(biofile)``
    load_data : :obj:`object`, optional
        A callable with the signature of
        ``data = load_data(database, biofile)``.
        :any:`bob.bio.base.read_original_data` is wrapped to be used by
        default.
    multiple_samples : bool, optional
        If true, it assumes that the bio database's samples actually contain
        multiple samples. This is useful for when you want to for example treat
        video databases as image databases.
    repeat : :obj:`int`, optional
        The samples are repeated ``repeat`` times. ``-1`` will make it repeat
        forever.
    output_types : (object, object, object)
        The types of the returned samples.
    output_shapes : (tf.TensorShape, tf.TensorShape, tf.TensorShape)
        The shapes of the returned samples.
    """

    def __init__(self, database, biofiles, load_data=None,
                 biofile_to_label=None, multiple_samples=False, repeat=1):
        """Store the configuration and probe the first file for type/shape.

        The first entry of ``biofiles`` is loaded once here so that
        ``output_types`` and ``output_shapes`` can be derived from it; an
        empty ``biofiles`` sequence therefore fails on ``biofiles[0]``.
        """
        if load_data is None:
            def load_data(database, biofile):
                data = read_original_data(
                    biofile,
                    database.original_directory,
                    database.original_extension)
                return data
        if biofile_to_label is None:
            def biofile_to_label(biofile):
                return -1

        # These must be real lists, not generator expressions: __call__
        # iterates over them once per epoch, and a generator would be
        # exhausted after the first epoch, making every later epoch empty
        # (and ``repeat=-1`` spin forever without yielding anything).
        self.labels = [biofile_to_label(f) for f in biofiles]
        self.keys = [str(f.make_path("", "")) for f in biofiles]
        self.database = database
        self.biofiles = biofiles
        self.load_data = load_data
        self.biofile_to_label = biofile_to_label
        self.multiple_samples = multiple_samples
        self.repeat = repeat
        self.epoch = 0

        # load one data to get its type and shape
        data = load_data(database, biofiles[0])
        if multiple_samples:
            try:
                data = data[0]
            except TypeError:
                # if the data is a generator
                data = six.next(data)
        data = tf.convert_to_tensor(data)
        # Each yielded sample is a (data, label, key) triple; label and key
        # are scalars.
        self._output_types = (data.dtype, tf.int64, tf.string)
        self._output_shapes = (
            data.shape, tf.TensorShape([]), tf.TensorShape([]))

        logger.debug("Initializing a dataset with %d files and %s types "
                     "and %s shapes", len(self.biofiles), self.output_types,
                     self.output_shapes)

    @property
    def output_types(self):
        """The types of the (data, label, key) samples that are yielded."""
        return self._output_types

    @property
    def output_shapes(self):
        """The shapes of the (data, label, key) samples that are yielded."""
        return self._output_shapes

    def __call__(self):
        """A generator function that when called will return the samples.

        One pass over ``biofiles`` counts as one epoch. Passes continue
        until ``epoch`` reaches ``repeat``, or forever when ``repeat`` is
        ``-1``.

        Yields
        ------
        (data, label, key) : tuple
            A tuple containing the data, label, and the key.
        """
        # NOTE: ``self.epoch`` is only set to 0 in ``__init__`` and keeps
        # counting across successive calls of this generator function.
        while True:
            for f, label, key in six.moves.zip(
                    self.biofiles, self.labels, self.keys):
                data = self.load_data(self.database, f)
                if self.multiple_samples:
                    # every sub-sample of a file shares that file's label
                    # and key
                    for d in data:
                        yield (d, label, key)
                else:
                    yield (data, label, key)
            self.epoch += 1
            logger.info("Elapsed %d epochs", self.epoch)
            if self.repeat != -1 and self.epoch >= self.repeat:
                break