Commit 7ea85d1b authored by Amir MOHAMMADI

refactor code

parent 59386bb9
import os
import six
import tensorflow as tf
from bob.bio.base.tools.grid import indices
from bob.bio.base import read_original_data as _read_original_data
def make_output_path(output_dir, key):
    return os.path.join(output_dir, key + '.hdf5')


def load_data(biofile, read_original_data, original_directory,
              original_extension):
    data = read_original_data(biofile, original_directory, original_extension)
    return data
def bio_generator(database, groups, number_of_parallel_jobs, output_dir,
                  read_original_data=None, biofile_to_label=None,
                  multiple_samples=False, force=False):
    if read_original_data is None:
        read_original_data = _read_original_data
    if biofile_to_label is None:
        def biofile_to_label(biofile):
            return -1
    biofiles = list(database.all_files(groups))
    if number_of_parallel_jobs > 1:
        start, end = indices(biofiles, number_of_parallel_jobs)
        biofiles = biofiles[start:end]
    labels = (biofile_to_label(f) for f in biofiles)
    keys = (str(f.make_path("", "")) for f in biofiles)

    def generator():
        for f, label, key in six.moves.zip(biofiles, labels, keys):
            outpath = make_output_path(output_dir, key)
            if not force and os.path.isfile(outpath):
                continue
            data = load_data(f, read_original_data,
                             database.original_directory,
                             database.original_extension)
            # when a biofile yields several samples, emit them one by one
            if multiple_samples:
                for d in data:
                    yield (d, label, key)
            else:
                yield (data, label, key)

    # load one sample to get its dtype and shape
    data = load_data(biofiles[0], read_original_data,
                     database.original_directory,
                     database.original_extension)
    if multiple_samples:
        try:
            data = data[0]
        except TypeError:
            # if the data is a generator
            data = six.next(data)
    data = tf.convert_to_tensor(data)
    output_types = (data.dtype, tf.int64, tf.string)
    output_shapes = (data.shape, tf.TensorShape([]), tf.TensorShape([]))
    return (generator, output_types, output_shapes)
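
The triple returned by bio_generator is shaped to feed tf.data.Dataset.from_generator directly. The snippet below is a minimal usage sketch, assuming a bob.bio.base database instance is at hand; the group name, output directory, and batch size are illustrative placeholders rather than values from this commit.

import tensorflow as tf
from bob.learn.tensorflow.dataset.bio import bio_generator

# `database` is assumed to be a bob.bio.base database object exposing
# `all_files`, `original_directory`, and `original_extension`.
generator, output_types, output_shapes = bio_generator(
    database, groups=['dev'], number_of_parallel_jobs=1,
    output_dir='/tmp/predictions')

# Build a dataset yielding (data, label, key) tuples and batch it.
dataset = tf.data.Dataset.from_generator(
    generator, output_types, output_shapes)
dataset = dataset.batch(32)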
@@ -110,67 +110,16 @@ import pkg_resources
import os
from multiprocessing import Pool
from collections import defaultdict
import six
import numpy as np
import tensorflow as tf
from bob.io.base import create_directories_safe
from bob.bio.base.utils import read_config_file, save
from bob.bio.base.tools.grid import indices
from bob.learn.tensorflow.utils.commandline import \
    get_from_config_or_commandline
from bob.learn.tensorflow.dataset.bio import make_output_path, bio_generator
from bob.core.log import setup, set_verbosity_level
logger = setup(__name__)
def make_output_path(output_dir, key):
    return os.path.join(output_dir, key + '.hdf5')


def bio_generator(database, groups, number_of_parallel_jobs, output_dir,
                  read_original_data=None, multiple_samples=False,
                  force=False):
    if read_original_data is None:
        from bob.bio.base import read_original_data
    biofiles = list(database.all_files(groups))
    if number_of_parallel_jobs > 1:
        start, end = indices(biofiles, number_of_parallel_jobs)
        biofiles = biofiles[start:end]
    keys = (str(f.make_path("", "")) for f in biofiles)

    def load_data(f, read_original_data, database):
        data = read_original_data(
            f,
            database.original_directory,
            database.original_extension)
        return data

    def generator():
        for f, key in six.moves.zip(biofiles, keys):
            outpath = make_output_path(output_dir, key)
            if not force and os.path.isfile(outpath):
                continue
            data = load_data(f, read_original_data, database)
            if multiple_samples:
                for d in data:
                    yield (d, -1, key)
            else:
                yield (data, -1, key)

    # load one data to get its type and shape
    data = load_data(biofiles[0], read_original_data, database)
    if multiple_samples:
        try:
            data = data[0]
        except TypeError:
            # if the data is a generator
            data = six.next(data)
    data = tf.convert_to_tensor(data)
    output_types = (data.dtype, tf.int64, tf.string)
    output_shapes = (data.shape, tf.TensorShape([]), tf.TensorShape([]))
    return (generator, output_types, output_shapes)
def save_predictions(pool, output_dir, key, pred_buffer):
    outpath = make_output_path(output_dir, key)
    create_directories_safe(os.path.dirname(outpath))
@@ -216,7 +165,8 @@ def main(argv=None):
    generator, output_types, output_shapes = bio_generator(
        database, groups, number_of_parallel_jobs, output_dir,
        read_original_data, multiple_samples, force)
        read_original_data=read_original_data, biofile_to_label=None,
        multiple_samples=multiple_samples, force=force)
    predict_input_fn = bio_predict_input_fn(generator,
                                            output_types, output_shapes)
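
For context, the input function built here is typically consumed by a tf.estimator.Estimator prediction loop; the sketch below illustrates that pattern under the assumption that an estimator has already been constructed earlier in the script (the `estimator` variable and the contents of each prediction dict are assumptions, not taken from this diff).

# Hypothetical consumption of predict_input_fn with a tf.estimator.Estimator;
# what each `pred` contains depends on the model_fn's predictions dict.
for pred in estimator.predict(predict_input_fn):
    print(pred)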