import logging

from bob.bio.base import read_original_data

from .generator import Generator

logger = logging.getLogger(__name__)

class BioGenerator(Generator):
    """A generator class which wraps bob.bio.base databases so that they can
    be used with tf.data.Dataset.from_generator

    Parameters
    ----------
    database : :any:`bob.bio.base.database.BioDatabase`
        The database that you want to use. It is passed as the first argument
        to ``load_data``.
    biofiles : [:any:`bob.bio.base.database.BioFile`]
        The list of the bio files. It is handed to the base generator as its
        samples.
    load_data : :obj:`object`, optional
        A callable with the signature of
        ``data = load_data(database, biofile)``.
        :any:`bob.bio.base.read_original_data` is wrapped to be used by
        default (reading from ``database.original_directory`` with
        ``database.original_extension``).
    biofile_to_label : :obj:`object`, optional
        A callable with the signature of ``label = biofile_to_label(biofile)``.
        Its result is converted with ``int``. By default -1 is returned as
        label.
    multiple_samples : bool, optional
        If True, the data loaded for one bio file is iterated and every item
        is yielded as its own ``(data, label, key)`` sample, all sharing that
        bio file's label and key. Also forwarded to the base generator.
    **kwargs
        Forwarded unchanged to the base generator's constructor.

    Attributes
    ----------
    database, load_data, biofile_to_label
        The values given to (or defaulted by) the constructor.
    biofiles : [:any:`bob.bio.base.database.BioFile`]
        The bio files, i.e. the base generator's ``samples``.
    keys : generator of bytes
        For each bio file, ``str(biofile.make_path("", ""))`` encoded as
        UTF-8. A new generator is returned on every access.
    labels : generator of int
        For each bio file, ``int(biofile_to_label(biofile))``. A new generator
        is returned on every access.
    """

    def __init__(
        self,
        database,
        biofiles,
        load_data=None,
        biofile_to_label=None,
        multiple_samples=False,
        **kwargs
    ):
        if load_data is None:

            def load_data(database, biofile):
                # Default loader: read the raw file from the database's
                # original directory using its original extension.
                return read_original_data(
                    biofile, database.original_directory, database.original_extension
                )

        if biofile_to_label is None:

            def biofile_to_label(biofile):
                # No labelling information available: use -1 for every file.
                return -1

        self.database = database
        self.load_data = load_data
        self.biofile_to_label = biofile_to_label

        def _reader(f):
            # The callables are looked up on ``self`` at call time, so later
            # reassignment of ``self.load_data`` / ``self.biofile_to_label``
            # is honoured.
            label = self._label_of(f)
            data = self.load_data(self.database, f)
            key = self._key_of(f)
            return data, label, key

        if multiple_samples:

            def reader(f):
                data, label, key = _reader(f)
                # Each element of the loaded data becomes a separate sample.
                for d in data:
                    yield (d, label, key)

        else:
            reader = _reader

        super(BioGenerator, self).__init__(
            biofiles, reader, multiple_samples=multiple_samples, **kwargs
        )

    def _label_of(self, biofile):
        """Return ``biofile``'s label as an ``int``."""
        return int(self.biofile_to_label(biofile))

    def _key_of(self, biofile):
        """Return ``biofile``'s key: its ``make_path("", "")`` as UTF-8 bytes."""
        return str(biofile.make_path("", "")).encode("utf-8")

    @property
    def labels(self):
        """Yield the integer label of every bio file, in order."""
        for f in self.biofiles:
            yield self._label_of(f)

    @property
    def keys(self):
        """Yield the UTF-8 encoded key (bytes) of every bio file, in order."""
        for f in self.biofiles:
            yield self._key_of(f)

    @property
    def biofiles(self):
        """The bio files; an alias for the base generator's ``samples``."""
        return self.samples

    def __len__(self):
        """Return the number of bio files.

        NOTE(review): with ``multiple_samples=True`` this counts bio files,
        not the (possibly larger) number of yielded samples.
        """
        return len(self.biofiles)