Commit aece6e1b authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

improve the generator and biogenerator classes

parent f6e1bb57
import six
import tensorflow as tf
from bob.bio.base import read_original_data from bob.bio.base import read_original_data
from .generator import Generator
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BioGenerator(object): class BioGenerator(Generator):
"""A generator class which wraps bob.bio.base databases so that they can """A generator class which wraps bob.bio.base databases so that they can
be used with tf.data.Dataset.from_generator be used with tf.data.Dataset.from_generator
...@@ -15,44 +14,37 @@ class BioGenerator(object): ...@@ -15,44 +14,37 @@ class BioGenerator(object):
biofile_to_label : :obj:`object`, optional biofile_to_label : :obj:`object`, optional
A callable with the signature of ``label = biofile_to_label(biofile)``. A callable with the signature of ``label = biofile_to_label(biofile)``.
By default -1 is returned as label. By default -1 is returned as label.
biofiles : [:any:`bob.bio.base.database.BioFile`]
The list of the bio files .
database : :any:`bob.bio.base.database.BioDatabase` database : :any:`bob.bio.base.database.BioDatabase`
The database that you want to use. The database that you want to use.
epoch : int
The number of epochs that have been passed so far.
keys : [str]
The keys of samples obtained by calling ``biofile.make_path("", "")``
labels : [int]
The labels obtained by calling ``label = biofile_to_label(biofile)``
load_data : :obj:`object`, optional load_data : :obj:`object`, optional
A callable with the signature of A callable with the signature of
``data = load_data(database, biofile)``. ``data = load_data(database, biofile)``.
:any:`bob.bio.base.read_original_data` is wrapped to be used by :any:`bob.bio.base.read_original_data` is wrapped to be used by
default. default.
multiple_samples : :obj:`bool`, optional biofiles : [:any:`bob.bio.base.database.BioFile`]
If true, it assumes that the bio database's samples actually contain The list of the bio files .
multiple samples. This is useful for when you want to for example treat keys : [str]
video databases as image databases. The keys of samples obtained by calling ``biofile.make_path("", "")``
output_types : (object, object, object) labels : [int]
The types of the returned samples. The labels obtained by calling ``label = biofile_to_label(biofile)``
output_shapes : ``(tf.TensorShape, tf.TensorShape, tf.TensorShape)``
The shapes of the returned samples.
""" """
def __init__(self, def __init__(
database, self,
biofiles, database,
load_data=None, biofiles,
biofile_to_label=None, load_data=None,
multiple_samples=False, biofile_to_label=None,
**kwargs): multiple_samples=False,
super(BioGenerator, self).__init__(**kwargs) **kwargs
):
if load_data is None: if load_data is None:
def load_data(database, biofile): def load_data(database, biofile):
data = read_original_data(biofile, database.original_directory, data = read_original_data(
database.original_extension) biofile, database.original_directory, database.original_extension
)
return data return data
if biofile_to_label is None: if biofile_to_label is None:
...@@ -61,29 +53,22 @@ class BioGenerator(object): ...@@ -61,29 +53,22 @@ class BioGenerator(object):
return -1 return -1
self.database = database self.database = database
self.biofiles = list(biofiles)
self.load_data = load_data self.load_data = load_data
self.biofile_to_label = biofile_to_label self.biofile_to_label = biofile_to_label
self.multiple_samples = multiple_samples
self.epoch = 0 def reader(f):
label = int(self.biofile_to_label(f))
# load one data to get its type and shape data = self.load_data(self.database, f)
data = load_data(database, biofiles[0]) key = str(f.make_path("", "")).encode("utf-8")
if multiple_samples: if self.multiple_samples:
try: for d in data:
data = data[0] yield (d, label, key)
except TypeError: else:
# if the data is a generator yield (data, label, key)
data = six.next(data)
data = tf.convert_to_tensor(data) super(BioGenerator, self).__init__(
self._output_types = (data.dtype, tf.int64, tf.string) biofiles, reader, multiple_samples=multiple_samples, **kwargs
self._output_shapes = (data.shape, tf.TensorShape([]), )
tf.TensorShape([]))
logger.info(
"Initializing a dataset with %d files and %s types "
"and %s shapes", len(self.biofiles), self.output_types,
self.output_shapes)
@property @property
def labels(self): def labels(self):
...@@ -93,34 +78,11 @@ class BioGenerator(object): ...@@ -93,34 +78,11 @@ class BioGenerator(object):
@property @property
def keys(self): def keys(self):
for f in self.biofiles: for f in self.biofiles:
yield str(f.make_path("", "")).encode('utf-8') yield str(f.make_path("", "")).encode("utf-8")
@property
def output_types(self):
return self._output_types
@property @property
def output_shapes(self): def biofiles(self):
return self._output_shapes return self.samples
def __len__(self): def __len__(self):
return len(self.biofiles) return len(self.biofiles)
def __call__(self):
"""A generator function that when called will return the samples.
Yields
------
(data, label, key) : tuple
A tuple containing the data, label, and the key.
"""
for f, label, key in six.moves.zip(self.biofiles, self.labels,
self.keys):
data = self.load_data(self.database, f)
if self.multiple_samples:
for d in data:
yield (d, label, key)
else:
yield (data, label, key)
self.epoch += 1
logger.info("Elapsed %d epoch(s)", self.epoch)
import six
import tensorflow as tf import tensorflow as tf
import random import random
import logging import logging
...@@ -39,14 +38,21 @@ class Generator: ...@@ -39,14 +38,21 @@ class Generator:
self.epoch = 0 self.epoch = 0
self.shuffle_on_epoch_end = shuffle_on_epoch_end self.shuffle_on_epoch_end = shuffle_on_epoch_end
# load one data to get its type and shape # load samples until one of them is not empty
dlk = self.reader(self.samples[0]) # this data is used to get the type and shape
if self.multiple_samples: for sample in self.samples:
try: try:
dlk = dlk[0] dlk = self.reader(sample)
except TypeError: if self.multiple_samples:
# if the data is a generator try:
dlk = six.next(dlk) dlk = dlk[0]
except TypeError:
# if the data is a generator
dlk = next(dlk)
except StopIteration:
continue
else:
break
# Creating a "fake" dataset just to get the types and shapes # Creating a "fake" dataset just to get the types and shapes
dataset = tf.data.Dataset.from_tensors(dlk) dataset = tf.data.Dataset.from_tensors(dlk)
self._output_types = dataset.output_types self._output_types = dataset.output_types
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment