Commit 6b189326 authored by Amir MOHAMMADI

refactor the bio generator

parent c2ee69bf
bob/learn/tensorflow/dataset/bio.py:

 import os
 import six
 import tensorflow as tf
-from bob.bio.base.tools.grid import indices
-from bob.bio.base import read_original_data as _read_original_data
+from bob.bio.base import read_original_data
-def make_output_path(output_dir, key):
-    """Returns an output path used for saving keys. You need to make sure the
-    directories leading to this output path exist.
-
-    Parameters
-    ----------
-    output_dir : str
-        The root directory to save the results.
-    key : str
-        The key of the sample. Usually ``biofile.make_path("", "")``.
-
-    Returns
-    -------
-    str
-        The path for the provided key.
-    """
-    return os.path.join(output_dir, key + '.hdf5')
-
-def bio_generator(database, groups, number_of_parallel_jobs, output_dir,
-                  read_original_data=None, biofile_to_label=None,
-                  multiple_samples=False, force=False):
+def bio_generator(database, biofiles, load_data=None, biofile_to_label=None,
+                  multiple_samples=False):
"""Returns a generator and its output types and shapes based on
bob.bio.base databases.
......@@ -34,16 +12,11 @@ def bio_generator(database, groups, number_of_parallel_jobs, output_dir,
     ----------
     database : :any:`bob.bio.base.database.BioDatabase`
         The database that you want to use.
-    groups : [str]
-        List of groups. Can be any permutation of ``('world', 'dev', 'eval')``.
-    number_of_parallel_jobs : int
-        The number of parallel jobs that the script has run with. This is
-        used to split the files into array jobs.
-    output_dir : str
-        The root directory where the data will be saved.
-    read_original_data : :obj:`object`, optional
+    biofiles : [:any:`bob.bio.base.database.BioFile`]
+        The list of bio files.
+    load_data : :obj:`object`, optional
         A callable with the signature of
-        ``data = read_original_data(biofile, directory, extension)``.
+        ``data = load_data(database, biofile)``.
         :any:`bob.bio.base.read_original_data` is used by default.
     biofile_to_label : :obj:`object`, optional
         A callable with the signature of ``label = biofile_to_label(biofile)``.
@@ -52,8 +25,6 @@ def bio_generator(database, groups, number_of_parallel_jobs, output_dir,
         If true, it assumes that the bio database's samples actually contain
         multiple samples. This is useful when you want to treat video
         databases as image databases.
-    force : bool, optional
-        If true, all files will be overwritten.

     Returns
     -------
@@ -65,25 +36,22 @@ def bio_generator(database, groups, number_of_parallel_jobs, output_dir,
     output_shapes : (tf.TensorShape, tf.TensorShape, tf.TensorShape)
         The shapes of the returned samples.
     """
-    if read_original_data is None:
-        read_original_data = _read_original_data
+    if load_data is None:
+        def load_data(database, biofile):
+            data = read_original_data(
+                biofile,
+                database.original_directory,
+                database.original_extension)
+            return data
     if biofile_to_label is None:
         def biofile_to_label(biofile):
             return -1
-    biofiles = list(database.all_files(groups))
-    if number_of_parallel_jobs > 1:
-        start, end = indices(biofiles, number_of_parallel_jobs)
-        biofiles = biofiles[start:end]
     labels = (biofile_to_label(f) for f in biofiles)
     keys = (str(f.make_path("", "")) for f in biofiles)
     def generator():
         for f, label, key in six.moves.zip(biofiles, labels, keys):
-            outpath = make_output_path(output_dir, key)
-            if not force and os.path.isfile(outpath):
-                continue
-            data = read_original_data(f, database.original_directory,
-                                      database.original_extension)
+            data = load_data(database, f)
             # labels
             if multiple_samples:
                 for d in data:
@@ -92,8 +60,8 @@ def bio_generator(database, groups, number_of_parallel_jobs, output_dir,
             yield (data, label, key)
     # load one data to get its type and shape
-    data = read_original_data(biofiles[0], database.original_directory,
-                              database.original_extension)
+    data = load_data(database, biofiles[0])
     if multiple_samples:
         try:
             data = data[0]
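For context, this is how the refactored generator might be wired into a
``tf.data`` pipeline. A minimal sketch, assuming ``database`` is any
:any:`bob.bio.base.database.BioDatabase` instance you provide; the
``bio_generator`` call and the new ``load_data(database, biofile)`` signature
come from this commit::

    import tensorflow as tf
    from bob.bio.base import read_original_data
    from bob.learn.tensorflow.dataset.bio import bio_generator

    database = ...  # placeholder: your BioDatabase instance
    groups = ['dev']
    biofiles = database.all_files(groups)

    # a custom loader matching the new ``load_data(database, biofile)`` API
    def load_data(database, biofile):
        return read_original_data(biofile, database.original_directory,
                                  database.original_extension)

    generator, output_types, output_shapes = bio_generator(
        database, biofiles, load_data=load_data)
    dataset = tf.data.Dataset.from_generator(generator, output_types,
                                             output_shapes)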
bob/learn/tensorflow/script/predict_bio.py:

@@ -53,9 +53,8 @@ The configuration files should have the following objects totally:
     An estimator instance that represents the neural network.
 database : :any:`bob.bio.base.database.BioDatabase`
     A bio database. Its original_directory must point to the correct path.
-groups : [str]
-    A list of groups to evaluate. Can be any permutation of
-    ``('world', 'dev', 'eval')``.
+biofiles : [:any:`bob.bio.base.database.BioFile`]
+    The list of bio files.
 bio_predict_input_fn : callable
     A callable with the signature of
     ``input_fn = bio_predict_input_fn(generator, output_types, output_shapes)``
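A hedged sketch of such a callable (the batch size and the ``data``/``key``
feature names are assumptions, not requirements of this script)::

    import tensorflow as tf

    def bio_predict_input_fn(generator, output_types, output_shapes):
        def input_fn():
            dataset = tf.data.Dataset.from_generator(
                generator, output_types, output_shapes)
            dataset = dataset.batch(32)
            data, label, key = dataset.make_one_shot_iterator().get_next()
            # features the estimator's model_fn will see during predict()
            return {'data': data, 'key': key}, label
        return input_fn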
@@ -65,9 +64,9 @@ The configuration files should have the following objects totally:
 # Optional objects:
-read_original_data : callable
+load_data : :obj:`object`, optional
     A callable with the signature of
-    ``data = read_original_data(biofile, directory, extension)``.
+    ``data = load_data(database, biofile)``.
     :any:`bob.bio.base.read_original_data` is used by default.
 hooks : [:any:`tf.train.SessionRunHook`]
     Optional hooks that you may want to attach to the predictions.
@@ -83,6 +82,7 @@ An example configuration for a trained model and its evaluation could be::
     estimator = tf.estimator.Estimator(model_fn, model_dir)
     groups = ['dev']
+    biofiles = database.all_files(groups)
     # the ``dataset = tf.data.Dataset.from_generator(generator, output_types,
@@ -113,13 +113,33 @@
 from collections import defaultdict
 import numpy as np
 from bob.io.base import create_directories_safe
 from bob.bio.base.utils import read_config_file, save
+from bob.bio.base.tools.grid import indices
 from bob.learn.tensorflow.utils.commandline import \
     get_from_config_or_commandline
-from bob.learn.tensorflow.dataset.bio import make_output_path, bio_generator
+from bob.learn.tensorflow.dataset.bio import bio_generator
 from bob.core.log import setup, set_verbosity_level

 logger = setup(__name__)
+def make_output_path(output_dir, key):
+    """Returns an output path used for saving keys. You need to make sure the
+    directories leading to this output path exist.
+
+    Parameters
+    ----------
+    output_dir : str
+        The root directory to save the results.
+    key : str
+        The key of the sample. Usually ``biofile.make_path("", "")``.
+
+    Returns
+    -------
+    str
+        The path for the provided key.
+    """
+    return os.path.join(output_dir, key + '.hdf5')
+
 def save_predictions(pool, output_dir, key, pred_buffer):
     outpath = make_output_path(output_dir, key)
     create_directories_safe(os.path.dirname(outpath))
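For instance, with the helper above (values made up)::

    >>> make_output_path('/tmp/predictions', 'client1/sample-001')
    '/tmp/predictions/client1/sample-001.hdf5'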
@@ -150,7 +170,7 @@ def main(argv=None):
     force = get_from_config_or_commandline(
         config, 'force', args, defaults)
     hooks = getattr(config, 'hooks', None)
-    read_original_data = getattr(config, 'read_original_data', None)
+    load_data = getattr(config, 'load_data', None)

     # Sets up logging
     set_verbosity_level(logger, verbosity)
@@ -158,15 +178,18 @@ def main(argv=None):
     # required arguments
     estimator = config.estimator
     database = config.database
-    groups = config.groups
+    biofiles = config.biofiles
     bio_predict_input_fn = config.bio_predict_input_fn
     output_dir = get_from_config_or_commandline(
         config, 'output_dir', args, defaults, False)

+    if number_of_parallel_jobs > 1:
+        start, end = indices(biofiles, number_of_parallel_jobs)
+        biofiles = biofiles[start:end]
     generator, output_types, output_shapes = bio_generator(
-        database, groups, number_of_parallel_jobs, output_dir,
-        read_original_data=read_original_data, biofile_to_label=None,
-        multiple_samples=multiple_samples, force=force)
+        database, biofiles, load_data=load_data,
+        biofile_to_label=None, multiple_samples=multiple_samples)
     predict_input_fn = bio_predict_input_fn(generator,
                                             output_types, output_shapes)
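Note that the per-job splitting of ``biofiles`` now happens in ``main``
instead of inside ``bio_generator``. A toy illustration of the idea (the
bounds shown are made up; the real ones come from
``bob.bio.base.tools.grid.indices`` and the grid job id)::

    # stand-ins for real BioFile objects
    biofiles = ['sample-%03d' % i for i in range(100)]
    number_of_parallel_jobs = 4
    job_id = 0  # hypothetical: the real id comes from the grid environment
    chunk = len(biofiles) // number_of_parallel_jobs
    start, end = job_id * chunk, (job_id + 1) * chunk
    assert biofiles[start:end] == biofiles[:25]  # job 0: first quarter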