Commit e1992171 authored by Amir MOHAMMADI

Convert the bio generator to a class

parent 467d4d05
 import six
 import tensorflow as tf
 from bob.bio.base import read_original_data
+import logging
+
+logger = logging.getLogger(__name__)
 
-def bio_generator(database, biofiles, load_data=None, biofile_to_label=None,
-                  multiple_samples=False, repeat=False):
-    """Returns a generator and its output types and shapes based on
-    bob.bio.base databases.
-
-    Parameters
+
+class BioGenerator(object):
+    """A generator class which wraps bob.bio.base databases so that they can
+    be used with tf.data.Dataset.from_generator.
+
+    Attributes
     ----------
-    database : :any:`bob.bio.base.database.BioDatabase`
-        The database that you want to use.
+    biofile_to_label : :obj:`object`, optional
+        A callable with the signature of ``label = biofile_to_label(biofile)``.
+        By default -1 is returned as label.
     biofiles : [:any:`bob.bio.base.database.BioFile`]
         The list of the bio files.
+    database : :any:`bob.bio.base.database.BioDatabase`
+        The database that you want to use.
+    epoch : int
+        The number of epochs that have been passed so far.
+    keys : [str]
+        The keys of samples obtained by calling ``biofile.make_path("", "")``.
+    labels : [int]
+        The labels obtained by calling ``label = biofile_to_label(biofile)``.
     load_data : :obj:`object`, optional
         A callable with the signature of
         ``data = load_data(database, biofile)``.
-        :any:`bob.bio.base.read_original_data` is used by default.
-    biofile_to_label : :obj:`object`, optional
-        A callable with the signature of ``label = biofile_to_label(biofile)``.
-        By default -1 is returned as label.
+        :any:`bob.bio.base.read_original_data` is wrapped to be used by
+        default.
     multiple_samples : bool, optional
         If True, it assumes that the bio database's samples actually contain
-        multiple samples. This is useful for when you want to treat video
-        databases as image databases.
-    repeat : bool, optional
-        If True, the samples are repeated forever.
-
-    Returns
-    -------
-    generator : object
-        A generator function that when called will return the samples. The
-        samples will be like ``(data, label, key)``.
+        multiple samples. This is useful for when you want to, for example,
+        treat video databases as image databases.
+    repeat : :obj:`int`, optional
+        The samples are repeated ``repeat`` times. ``-1`` will make it repeat
+        forever.
     output_types : (object, object, object)
         The types of the returned samples.
     output_shapes : (tf.TensorShape, tf.TensorShape, tf.TensorShape)
         The shapes of the returned samples.
     """
 
-    if load_data is None:
-        def load_data(database, biofile):
-            data = read_original_data(
-                biofile,
-                database.original_directory,
-                database.original_extension)
-            return data
-
-    if biofile_to_label is None:
-        def biofile_to_label(biofile):
-            return -1
-
-    labels = (biofile_to_label(f) for f in biofiles)
-    keys = (str(f.make_path("", "")) for f in biofiles)
-
-    def generator():
+    def __init__(self, database, biofiles, load_data=None,
+                 biofile_to_label=None, multiple_samples=False, repeat=1):
+        if load_data is None:
+            def load_data(database, biofile):
+                data = read_original_data(
+                    biofile,
+                    database.original_directory,
+                    database.original_extension)
+                return data
+        if biofile_to_label is None:
+            def biofile_to_label(biofile):
+                return -1
+
+        self.labels = (biofile_to_label(f) for f in biofiles)
+        self.keys = (str(f.make_path("", "")) for f in biofiles)
+        self.database = database
+        self.biofiles = biofiles
+        self.load_data = load_data
+        self.biofile_to_label = biofile_to_label
+        self.multiple_samples = multiple_samples
+        self.repeat = repeat
+        self.epoch = 0
+
+        # load one sample to get its type and shape
+        data = load_data(database, biofiles[0])
+        if multiple_samples:
+            try:
+                data = data[0]
+            except TypeError:
+                # if the data is a generator
+                data = six.next(data)
+        data = tf.convert_to_tensor(data)
+        self._output_types = (data.dtype, tf.int64, tf.string)
+        self._output_shapes = (
+            data.shape, tf.TensorShape([]), tf.TensorShape([]))
+
+        logger.debug("Initializing a dataset with %d files and %s types "
+                     "and %s shapes", len(self.biofiles), self.output_types,
+                     self.output_shapes)
+
+    @property
+    def output_types(self):
+        return self._output_types
+
+    @property
+    def output_shapes(self):
+        return self._output_shapes
+
+    def __call__(self):
+        """A generator function that, when called, will return the samples.
+
+        Yields
+        ------
+        (data, label, key) : tuple
+            A tuple containing the data, label, and the key.
+        """
         while True:
-            for f, label, key in six.moves.zip(biofiles, labels, keys):
-                data = load_data(database, f)
+            for f, label, key in six.moves.zip(
+                    self.biofiles, self.labels, self.keys):
+                data = self.load_data(self.database, f)
                 # labels
-                if multiple_samples:
+                if self.multiple_samples:
                     for d in data:
                         yield (d, label, key)
                 else:
                     yield (data, label, key)
-            if not repeat:
+            self.epoch += 1
+            logger.info("Elapsed %d epochs", self.epoch)
+            if self.repeat != -1 and self.epoch >= self.repeat:
                 break
-
-    # load one sample to get its type and shape
-    data = load_data(database, biofiles[0])
-    if multiple_samples:
-        try:
-            data = data[0]
-        except TypeError:
-            # if the data is a generator
-            data = six.next(data)
-    data = tf.convert_to_tensor(data)
-    output_types = (data.dtype, tf.int64, tf.string)
-    output_shapes = (data.shape, tf.TensorShape([]), tf.TensorShape([]))
-
-    return (generator, output_types, output_shapes)
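
For context (this snippet is not part of the commit), the new class is meant to be passed directly to ``tf.data.Dataset.from_generator``. A minimal usage sketch, assuming ``database`` and ``biofiles`` are already provided by a bob.bio.base configuration, could look like::

    import tensorflow as tf
    from bob.learn.tensorflow.dataset.bio import BioGenerator

    # database and biofiles are assumed to come from a bob.bio.base
    # configuration file; they are placeholders here
    generator = BioGenerator(database, biofiles, repeat=1)

    # each element of the dataset is a (data, label, key) tuple
    dataset = tf.data.Dataset.from_generator(
        generator, generator.output_types, generator.output_shapes)
    dataset = dataset.batch(32)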
@@ -89,7 +89,7 @@ An example configuration for a trained model and its evaluation could be::
 # output_shapes)`` line is mandatory in the function below. You have to
 # create it in your configuration file since you want it to be created in
 # the same graph as your model.
-def bio_predict_input_fn(generator,output_types, output_shapes):
+def bio_predict_input_fn(generator, output_types, output_shapes):
     def input_fn():
         dataset = tf.data.Dataset.from_generator(generator, output_types,
                                                  output_shapes)
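
The hunk above is truncated by the diff view. One plausible completion of ``input_fn`` (the batching, prefetching, and the returned feature dictionary below are assumptions for illustration, not code from this repository) would be::

    def bio_predict_input_fn(generator, output_types, output_shapes):
        def input_fn():
            dataset = tf.data.Dataset.from_generator(
                generator, output_types, output_shapes)
            # hypothetical batching and prefetching choices
            dataset = dataset.batch(1).prefetch(1)
            data, label, key = dataset.make_one_shot_iterator().get_next()
            # hypothetical feature layout expected by the estimator
            return {'data': data, 'key': key}, label
        return input_fn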
@@ -116,7 +116,7 @@ from bob.bio.base.utils import read_config_file, save
 from bob.bio.base.tools.grid import indices
 from bob.learn.tensorflow.utils.commandline import \
     get_from_config_or_commandline
-from bob.learn.tensorflow.dataset.bio import bio_generator
+from bob.learn.tensorflow.dataset.bio import BioGenerator
 from bob.core.log import setup, set_verbosity_level
 
 logger = setup(__name__)
@@ -172,6 +172,9 @@ def main(argv=None):
     hooks = getattr(config, 'hooks', None)
     load_data = getattr(config, 'load_data', None)
 
+    # TODO(amir): implement force and pre-filtering
+    raise ValueError("This script is not fully implemented yet!")
+
     # Sets-up logging
     set_verbosity_level(logger, verbosity)
 
@@ -187,12 +190,12 @@ def main(argv=None):
     start, end = indices(biofiles, number_of_parallel_jobs)
     biofiles = biofiles[start:end]
 
-    generator, output_types, output_shapes = bio_generator(
-        database, biofiles, load_data=load_data,
-        biofile_to_label=None, multiple_samples=multiple_samples, force=force)
-    predict_input_fn = bio_predict_input_fn(generator,
-                                            output_types, output_shapes)
+    generator = BioGenerator(
+        database, biofiles, load_data=load_data,
+        multiple_samples=multiple_samples)
+    predict_input_fn = bio_predict_input_fn(
+        generator, generator.output_types, generator.output_shapes)
 
     predictions = estimator.predict(
         predict_input_fn,
...
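
``tf.estimator.Estimator.predict`` returns a lazy generator, so the loop that consumes it is what actually drives the computation. A hedged sketch of that consumption (the ``save_prediction`` helper and the pairing with ``biofiles`` are hypothetical and assume ``repeat=1`` with one sample per file)::

    for biofile, prediction in zip(biofiles, predictions):
        # prediction holds whatever the model_fn exposes as predictions
        save_prediction(prediction, biofile)  # hypothetical helper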