Commit aece6e1b authored by Amir MOHAMMADI

improve the generator and biogenerator classes

parent f6e1bb57
Merge request !79: Add keras-based models, add pixel-wise loss, other improvements
-import six
 import tensorflow as tf
 from bob.bio.base import read_original_data
+from .generator import Generator
 import logging

 logger = logging.getLogger(__name__)


-class BioGenerator(object):
+class BioGenerator(Generator):
     """A generator class which wraps bob.bio.base databases so that they can
     be used with tf.data.Dataset.from_generator
@@ -15,44 +14,37 @@ class BioGenerator(object):
     biofile_to_label : :obj:`object`, optional
         A callable with the signature of ``label = biofile_to_label(biofile)``.
         By default -1 is returned as label.
-    biofiles : [:any:`bob.bio.base.database.BioFile`]
-        The list of the bio files.
     database : :any:`bob.bio.base.database.BioDatabase`
         The database that you want to use.
-    epoch : int
-        The number of epochs that have been passed so far.
-    keys : [str]
-        The keys of samples obtained by calling ``biofile.make_path("", "")``
-    labels : [int]
-        The labels obtained by calling ``label = biofile_to_label(biofile)``
     load_data : :obj:`object`, optional
         A callable with the signature of
        ``data = load_data(database, biofile)``.
         :any:`bob.bio.base.read_original_data` is wrapped to be used by
         default.
-    multiple_samples : :obj:`bool`, optional
-        If true, it assumes that the bio database's samples actually contain
-        multiple samples. This is useful for when you want to for example treat
-        video databases as image databases.
-    output_types : (object, object, object)
-        The types of the returned samples.
-    output_shapes : ``(tf.TensorShape, tf.TensorShape, tf.TensorShape)``
-        The shapes of the returned samples.
+    biofiles : [:any:`bob.bio.base.database.BioFile`]
+        The list of the bio files.
+    keys : [str]
+        The keys of samples obtained by calling ``biofile.make_path("", "")``
+    labels : [int]
+        The labels obtained by calling ``label = biofile_to_label(biofile)``
     """

-    def __init__(self,
-                 database,
-                 biofiles,
-                 load_data=None,
-                 biofile_to_label=None,
-                 multiple_samples=False,
-                 **kwargs):
-        super(BioGenerator, self).__init__(**kwargs)
+    def __init__(
+        self,
+        database,
+        biofiles,
+        load_data=None,
+        biofile_to_label=None,
+        multiple_samples=False,
+        **kwargs
+    ):

         if load_data is None:

             def load_data(database, biofile):
-                data = read_original_data(biofile, database.original_directory,
-                                          database.original_extension)
+                data = read_original_data(
+                    biofile, database.original_directory, database.original_extension
+                )
                 return data

         if biofile_to_label is None:
@@ -61,29 +53,22 @@ class BioGenerator(object):
                 return -1

         self.database = database
-        self.biofiles = list(biofiles)
         self.load_data = load_data
         self.biofile_to_label = biofile_to_label
-        self.multiple_samples = multiple_samples
-        self.epoch = 0

-        # load one data to get its type and shape
-        data = load_data(database, biofiles[0])
-        if multiple_samples:
-            try:
-                data = data[0]
-            except TypeError:
-                # if the data is a generator
-                data = six.next(data)
-        data = tf.convert_to_tensor(data)
-        self._output_types = (data.dtype, tf.int64, tf.string)
-        self._output_shapes = (data.shape, tf.TensorShape([]),
-                               tf.TensorShape([]))
-
-        logger.info(
-            "Initializing a dataset with %d files and %s types "
-            "and %s shapes", len(self.biofiles), self.output_types,
-            self.output_shapes)
+        def reader(f):
+            label = int(self.biofile_to_label(f))
+            data = self.load_data(self.database, f)
+            key = str(f.make_path("", "")).encode("utf-8")
+            if self.multiple_samples:
+                for d in data:
+                    yield (d, label, key)
+            else:
+                yield (data, label, key)
+
+        super(BioGenerator, self).__init__(
+            biofiles, reader, multiple_samples=multiple_samples, **kwargs
+        )

     @property
     def labels(self):
@@ -93,34 +78,11 @@ class BioGenerator(object):
     @property
     def keys(self):
         for f in self.biofiles:
-            yield str(f.make_path("", "")).encode('utf-8')
+            yield str(f.make_path("", "")).encode("utf-8")

-    @property
-    def output_types(self):
-        return self._output_types
-
     @property
-    def output_shapes(self):
-        return self._output_shapes
+    def biofiles(self):
+        return self.samples

     def __len__(self):
         return len(self.biofiles)
-
-    def __call__(self):
-        """A generator function that when called will return the samples.
-
-        Yields
-        ------
-        (data, label, key) : tuple
-            A tuple containing the data, label, and the key.
-        """
-        for f, label, key in six.moves.zip(self.biofiles, self.labels,
-                                           self.keys):
-            data = self.load_data(self.database, f)
-            if self.multiple_samples:
-                for d in data:
-                    yield (d, label, key)
-            else:
-                yield (data, label, key)
-        self.epoch += 1
-        logger.info("Elapsed %d epoch(s)", self.epoch)
-import six
 import tensorflow as tf
 import random
 import logging
@@ -39,14 +38,21 @@ class Generator:
         self.epoch = 0
         self.shuffle_on_epoch_end = shuffle_on_epoch_end

-        # load one data to get its type and shape
-        dlk = self.reader(self.samples[0])
-        if self.multiple_samples:
-            try:
-                dlk = dlk[0]
-            except TypeError:
-                # if the data is a generator
-                dlk = six.next(dlk)
+        # load samples until one of them is not empty
+        # this data is used to get the type and shape
+        for sample in self.samples:
+            try:
+                dlk = self.reader(sample)
+                if self.multiple_samples:
+                    try:
+                        dlk = dlk[0]
+                    except TypeError:
+                        # if the data is a generator
+                        dlk = next(dlk)
+            except StopIteration:
+                continue
+            else:
+                break

         # Creating a "fake" dataset just to get the types and shapes
         dataset = tf.data.Dataset.from_tensors(dlk)
         self._output_types = dataset.output_types
...
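
The probing change above makes the type/shape inference skip samples whose reader raises StopIteration before producing anything, and it still relies on building a throwaway one-element dataset just to read off the element signature. A standalone sketch of that trick, assuming the TF 1.x-era tf.data API used here (where datasets expose output_types and output_shapes directly); the example tuple is made up:

    import numpy as np
    import tensorflow as tf

    # one concrete (data, label, key) example; shape and dtypes are illustrative
    example = (np.zeros((112, 112, 3), dtype="float32"), 0, b"subject_01/sample_01")

    # a single-element dataset is enough to infer the element signature
    probe = tf.data.Dataset.from_tensors(example)
    output_types = probe.output_types    # (tf.float32, tf.int32, tf.string)
    output_shapes = probe.output_shapes  # ((112, 112, 3), (), ())

    # the inferred signature can then drive Dataset.from_generator
    def gen():
        yield example

    dataset = tf.data.Dataset.from_generator(gen, output_types, output_shapes)
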