diff --git a/bob/learn/tensorflow/script/db_to_tfrecords.py b/bob/learn/tensorflow/script/db_to_tfrecords.py
index f4e234bc2df6927b96790388023d9e6dda2c4e6a..7289c6d0fcf15c27f0bbe94f70bd99b62a5094b2 100644
--- a/bob/learn/tensorflow/script/db_to_tfrecords.py
+++ b/bob/learn/tensorflow/script/db_to_tfrecords.py
@@ -62,6 +62,10 @@ The config files should have the following objects totally:
   reader = Extractor().read_feature
   # or
   from bob.bio.base.utils import load as reader
+  # or a reader that casts images to uint8:
+  def reader(path):
+      data = bob.bio.base.utils.load(path)
+      return data.astype("uint8")
 
   # extension of the preprocessed files. [default: '.hdf5']
   data_extension = '.hdf5'
@@ -69,6 +73,10 @@
   # Shuffle the files before writing them into a tfrecords. [default: False]
   shuffle = True
 
+  # Whether each file contains one sample or more. [default: True] If
+  # this is False, the loaded samples from a file are iterated over and each
+  # of them is saved as an independent feature.
+  one_file_one_sample = True
 """
 
 from __future__ import absolute_import
@@ -81,7 +89,6 @@
 from bob.io.base import create_directories_safe
 from bob.bio.base.utils import load, read_config_file
 from bob.core.log import setup, set_verbosity_level
 logger = setup(__name__)
-import numpy
 
 def _bytes_feature(value):
@@ -92,73 +99,63 @@
     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 
 
-def bob2skimage(bob_image):
-    """
-    Convert bob color image to the skcit image
-    """
-
-    if bob_image.ndim==2:
-        skimage = numpy.zeros(shape=(bob_image.shape[0], bob_image.shape[1], 1))
-        skimage[:, :, 0] = bob_image
-    else:
-        skimage = numpy.zeros(shape=(bob_image.shape[1], bob_image.shape[2], bob_image.shape[0]))
-        skimage[:, :, 2] = bob_image[0, :, :]
-        skimage[:, :, 1] = bob_image[1, :, :]
-        skimage[:, :, 0] = bob_image[2, :, :]
+def write_a_sample(writer, data, label):
+    feature = {'train/data': _bytes_feature(data.tostring()),
+               'train/label': _int64_feature(label)}
+
+    example = tf.train.Example(features=tf.train.Features(feature=feature))
+    writer.write(example.SerializeToString())
 
-    return skimage
 
 def main(argv=None):
-  from docopt import docopt
-  import os
-  import sys
-  import pkg_resources
-  docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
-  version = pkg_resources.require('bob.learn.tensorflow')[0].version
-  args = docopt(docs, argv=argv, version=version)
-  config_files = args['<config_files>']
-  config = read_config_file(config_files)
-
-  # Sets-up logging
-  verbosity = getattr(config, 'verbose', 0)
-  set_verbosity_level(logger, verbosity)
-
-  database = config.database
-  data_dir, output_dir = config.data_dir, config.output_dir
-  file_to_label = config.file_to_label
-
-  reader = getattr(config, 'reader', load)
-  groups = getattr(config, 'groups', ['world'])
-  data_extension = getattr(config, 'data_extension', '.hdf5')
-  shuffle = getattr(config, 'shuffle', False)
-
-  data_type = getattr(config, 'data_type', "float32")
-
-  create_directories_safe(output_dir)
-  if not isinstance(groups, (list, tuple)):
-    groups = [groups]
-
-  for group in groups:
-    output_file = os.path.join(output_dir, '{}.tfrecords'.format(group))
-    files = database.all_files(groups=group)
-    if shuffle:
-      random.shuffle(files)
-    n_files = len(files)
-    with tf.python_io.TFRecordWriter(output_file) as writer:
-      for i, f in enumerate(files):
-        logger.info('Processing file %d out of %d', i + 1, n_files)
-
-        path = f.make_path(data_dir, data_extension)
-        img = bob2skimage(reader(path)).astype(data_type)
-        img = img.reshape((list(img.shape) + [1]))
-        data = img.tostring()
-
-        feature = {'train/data': _bytes_feature(data),
-                   'train/label': _int64_feature(file_to_label(f))}
-
-        example = tf.train.Example(features=tf.train.Features(feature=feature))
-        writer.write(example.SerializeToString())
+    from docopt import docopt
+    import os
+    import sys
+    import pkg_resources
+    docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
+    version = pkg_resources.require('bob.learn.tensorflow')[0].version
+    args = docopt(docs, argv=argv, version=version)
+    config_files = args['<config_files>']
+    config = read_config_file(config_files)
+
+    # Set up logging
+    verbosity = getattr(config, 'verbose', 0)
+    set_verbosity_level(logger, verbosity)
+
+    database = config.database
+    data_dir, output_dir = config.data_dir, config.output_dir
+    file_to_label = config.file_to_label
+
+    reader = getattr(config, 'reader', load)
+    groups = getattr(config, 'groups', ['world'])
+    data_extension = getattr(config, 'data_extension', '.hdf5')
+    shuffle = getattr(config, 'shuffle', False)
+    one_file_one_sample = getattr(config, 'one_file_one_sample', True)
+
+    create_directories_safe(output_dir)
+    if not isinstance(groups, (list, tuple)):
+        groups = [groups]
+
+    for group in groups:
+        output_file = os.path.join(output_dir, '{}.tfrecords'.format(group))
+        files = database.all_files(groups=group)
+        if shuffle:
+            random.shuffle(files)
+        n_files = len(files)
+        with tf.python_io.TFRecordWriter(output_file) as writer:
+            for i, f in enumerate(files):
+                logger.info('Processing file %d out of %d', i + 1, n_files)
+
+                path = f.make_path(data_dir, data_extension)
+                data = reader(path)
+                label = file_to_label(f)
+
+                if one_file_one_sample:
+                    write_a_sample(writer, data, label)
+                else:
+                    for sample in data:
+                        write_a_sample(writer, sample, label)
 
 
 if __name__ == '__main__':
-  main()
+    main()
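
Below is a minimal sketch of a config file exercising the two additions in this patch (the uint8 reader override and the one_file_one_sample switch). It follows the objects documented in the patched docstring; the Database class, the paths, and the client_id-based label mapping are illustrative placeholders, not part of the patch.

    # config.py (hypothetical)
    from bob.bio.base.utils import load

    database = Database()  # placeholder: a bob.bio.base-style database instance
    data_dir = '/path/to/preprocessed'
    output_dir = '/path/to/tfrecords'
    groups = ['world']
    data_extension = '.hdf5'
    shuffle = True

    # each preprocessed file holds several samples; write each one separately
    one_file_one_sample = False

    def reader(path):
        # cast to uint8 so the serialized bytes stay compact
        return load(path).astype('uint8')

    def file_to_label(f):
        # placeholder mapping; _int64_feature() needs an integer label
        return int(f.client_id)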
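One behavior worth noting for review: write_a_sample() stores only the raw bytes of the array and an integer label; neither shape nor dtype is serialized, so consumers must know both from the configured reader. A quick sanity check of a generated file can be sketched with the same tf.python_io API family the script already uses ('world.tfrecords' and the uint8 dtype are assumptions carried over from the config sketch above):

    import numpy as np
    import tensorflow as tf

    for serialized in tf.python_io.tf_record_iterator('world.tfrecords'):
        example = tf.train.Example()
        example.ParseFromString(serialized)
        raw = example.features.feature['train/data'].bytes_list.value[0]
        label = example.features.feature['train/label'].int64_list.value[0]
        # only a flat array is recoverable; the writer does not store the shape
        data = np.frombuffer(raw, dtype=np.uint8)
        print(data.size, label)
        break  # inspect just the first record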