Skip to content
Snippets Groups Projects
Commit e1992171 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Convert the bio generator to a class

parent 467d4d05
No related branches found
No related tags found
1 merge request!33Changes to the biogenerator
import six
import tensorflow as tf
from bob.bio.base import read_original_data
import logging
logger = logging.getLogger(__name__)
def bio_generator(database, biofiles, load_data=None, biofile_to_label=None,
multiple_samples=False, repeat=False):
"""Returns a generator and its output types and shapes based on
bob.bio.base databases.
Parameters
class BioGenerator(object):
"""A generator class which wraps bob.bio.base databases so that they can
be used with tf.data.Dataset.from_generator
Attributes
----------
database : :any:`bob.bio.base.database.BioDatabase`
The database that you want to use.
biofile_to_label : :obj:`object`, optional
A callable with the signature of ``label = biofile_to_label(biofile)``.
By default -1 is returned as label.
biofiles : [:any:`bob.bio.base.database.BioFile`]
The list of the bio files .
database : :any:`bob.bio.base.database.BioDatabase`
The database that you want to use.
epoch : int
The number of epochs that have been passed so far.
keys : [str]
The keys of samples obtained by calling ``biofile.make_path("", "")``
labels : [int]
The labels obtained by calling ``label = biofile_to_label(biofile)``
load_data : :obj:`object`, optional
A callable with the signature of
``data = load_data(database, biofile)``.
:any:`bob.bio.base.read_original_data` is used by default.
biofile_to_label : :obj:`object`, optional
A callable with the signature of ``label = biofile_to_label(biofile)``.
By default -1 is returned as label.
:any:`bob.bio.base.read_original_data` is wrapped to be used by
default.
multiple_samples : bool, optional
If true, it assumes that the bio database's samples actually contain
multiple samples. This is useful for when you want to treat video
databases as image databases.
repeat : bool, optional
If True, the samples are repeated forever.
Returns
-------
generator : object
A generator function that when called will return the samples. The
samples will be like ``(data, label, key)``.
multiple samples. This is useful for when you want to for example treat
video databases as image databases.
repeat : :obj:`int`, optional
The samples are repeated ``repeat`` times. ``-1`` will make it repeat
forever.
output_types : (object, object, object)
The types of the returned samples.
output_shapes : (tf.TensorShape, tf.TensorShape, tf.TensorShape)
The shapes of the returned samples.
"""
if load_data is None:
def load_data(database, biofile):
data = read_original_data(
biofile,
database.original_directory,
database.original_extension)
return data
if biofile_to_label is None:
def biofile_to_label(biofile):
return -1
labels = (biofile_to_label(f) for f in biofiles)
keys = (str(f.make_path("", "")) for f in biofiles)
def generator():
def __init__(self, database, biofiles, load_data=None,
biofile_to_label=None, multiple_samples=False, repeat=1):
if load_data is None:
def load_data(database, biofile):
data = read_original_data(
biofile,
database.original_directory,
database.original_extension)
return data
if biofile_to_label is None:
def biofile_to_label(biofile):
return -1
self.labels = (biofile_to_label(f) for f in biofiles)
self.keys = (str(f.make_path("", "")) for f in biofiles)
self.database = database
self.biofiles = biofiles
self.load_data = load_data
self.biofile_to_label = biofile_to_label
self.multiple_samples = multiple_samples
self.repeat = repeat
self.epoch = 0
# load one data to get its type and shape
data = load_data(database, biofiles[0])
if multiple_samples:
try:
data = data[0]
except TypeError:
# if the data is a generator
data = six.next(data)
data = tf.convert_to_tensor(data)
self._output_types = (data.dtype, tf.int64, tf.string)
self._output_shapes = (
data.shape, tf.TensorShape([]), tf.TensorShape([]))
logger.debug("Initializing a dataset with %d files and %s types "
"and %s shapes", len(self.biofiles), self.output_types,
self.output_shapes)
@property
def output_types(self):
return self._output_types
@property
def output_shapes(self):
return self._output_shapes
def __call__(self):
"""A generator function that when called will return the samples.
Yields
------
(data, label, key) : tuple
A tuple containing the data, label, and the key.
"""
while True:
for f, label, key in six.moves.zip(biofiles, labels, keys):
data = load_data(database, f)
for f, label, key in six.moves.zip(
self.biofiles, self.labels, self.keys):
data = self.load_data(self.database, f)
# labels
if multiple_samples:
if self.multiple_samples:
for d in data:
yield (d, label, key)
else:
yield (data, label, key)
if not repeat:
self.epoch += 1
logger.info("Elapsed %d epochs", self.epoch)
if self.repeat != -1 and self.epoch >= self.repeat:
break
# load one data to get its type and shape
data = load_data(database, biofiles[0])
if multiple_samples:
try:
data = data[0]
except TypeError:
# if the data is a generator
data = six.next(data)
data = tf.convert_to_tensor(data)
output_types = (data.dtype, tf.int64, tf.string)
output_shapes = (data.shape, tf.TensorShape([]), tf.TensorShape([]))
return (generator, output_types, output_shapes)
......@@ -89,7 +89,7 @@ An example configuration for a trained model and its evaluation could be::
# output_shapes)`` line is mandatory in the function below. You have to
# create it in your configuration file since you want it to be created in
# the same graph as your model.
def bio_predict_input_fn(generator,output_types, output_shapes):
def bio_predict_input_fn(generator, output_types, output_shapes):
def input_fn():
dataset = tf.data.Dataset.from_generator(generator, output_types,
output_shapes)
......@@ -116,7 +116,7 @@ from bob.bio.base.utils import read_config_file, save
from bob.bio.base.tools.grid import indices
from bob.learn.tensorflow.utils.commandline import \
get_from_config_or_commandline
from bob.learn.tensorflow.dataset.bio import bio_generator
from bob.learn.tensorflow.dataset.bio import BioGenerator
from bob.core.log import setup, set_verbosity_level
logger = setup(__name__)
......@@ -172,6 +172,9 @@ def main(argv=None):
hooks = getattr(config, 'hooks', None)
load_data = getattr(config, 'load_data', None)
# TODO(amir): implement force and pre-filtering
raise ValueError("This script is not fully implemented yet!")
# Sets-up logging
set_verbosity_level(logger, verbosity)
......@@ -187,12 +190,12 @@ def main(argv=None):
start, end = indices(biofiles, number_of_parallel_jobs)
biofiles = biofiles[start:end]
generator, output_types, output_shapes = bio_generator(
generator = BioGenerator(
database, biofiles, load_data=load_data,
biofile_to_label=None, multiple_samples=multiple_samples, force=force)
multiple_samples=multiple_samples)
predict_input_fn = bio_predict_input_fn(generator,
output_types, output_shapes)
predict_input_fn = bio_predict_input_fn(
generator, generator.output_types, generator.output_shapes)
predictions = estimator.predict(
predict_input_fn,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment