Commit 2f2a62c9 authored by Amir MOHAMMADI

Add support for databases which contain more than one sample in one file

parent 578f27a2
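
In short: when one_file_one_sample = False, the conversion loop iterates over whatever the configured reader returns for a file and writes each element as its own record. A minimal sketch of such a multi-sample reader follows; it assumes h5py is available and a hypothetical "array" dataset key, and is illustrative rather than part of this commit:

import h5py

def reader(path):
    # Returns an array of shape (n_samples, ...); with
    # one_file_one_sample = False, each row is written as its own
    # tf.train.Example by the loop added in this commit.
    with h5py.File(path, 'r') as f:
        return f['array'][()]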
@@ -62,6 +62,10 @@ The config files should have the following objects in total:
reader = Extractor().read_feature
# or
from bob.bio.base.utils import load as reader
# or a reader that casts images to uint8:
def reader(path):
data = bob.bio.base.utils.load(path)
return data.astype("uint8")
# extension of the preprocessed files. [default: '.hdf5']
data_extension = '.hdf5'
@@ -69,6 +73,10 @@ The config files should have the following objects in total:
# Shuffle the files before writing them into a tfrecords. [default: False]
shuffle = True
# Whether each file contains one sample or more. [default: True] If
# this is False, the samples loaded from a file are iterated over and
# each of them is saved as an independent feature.
one_file_one_sample = True
"""
from __future__ import absolute_import
@@ -81,7 +89,6 @@ from bob.io.base import create_directories_safe
from bob.bio.base.utils import load, read_config_file
from bob.core.log import setup, set_verbosity_level
logger = setup(__name__)
import numpy
def _bytes_feature(value):
@@ -92,73 +99,63 @@ def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def bob2skimage(bob_image):
    """
    Convert a Bob color image to a scikit-image style array.
    """
    if bob_image.ndim == 2:
        # gray-scale: add a singleton channel axis
        skimage = numpy.zeros(shape=(bob_image.shape[0], bob_image.shape[1], 1))
        skimage[:, :, 0] = bob_image
    else:
        # color: move the channel axis last and reverse the channel order
        skimage = numpy.zeros(shape=(bob_image.shape[1], bob_image.shape[2], bob_image.shape[0]))
        skimage[:, :, 2] = bob_image[0, :, :]
        skimage[:, :, 1] = bob_image[1, :, :]
        skimage[:, :, 0] = bob_image[2, :, :]
    return skimage


def write_a_sample(writer, data, label):
    # Serialize one (data, label) pair as a tf.train.Example record.
    feature = {'train/data': _bytes_feature(data.tostring()),
               'train/label': _int64_feature(label)}
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(example.SerializeToString())
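# Illustrative counterpart, not part of this commit: a record written by
# write_a_sample can be parsed back with a feature map mirroring the
# 'train/data' / 'train/label' keys above. read_a_sample is a hypothetical
# helper; its dtype and shape arguments are assumptions that must match
# what the reader produced at write time. Uses TF 1.x APIs, consistent
# with tf.python_io.TFRecordWriter used below.
def read_a_sample(serialized_example, shape, dtype=tf.uint8):
    features = {'train/data': tf.FixedLenFeature([], tf.string),
                'train/label': tf.FixedLenFeature([], tf.int64)}
    parsed = tf.parse_single_example(serialized_example, features=features)
    data = tf.decode_raw(parsed['train/data'], dtype)
    data = tf.reshape(data, shape)
    return data, parsed['train/label']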
def main(argv=None):
from docopt import docopt
import os
import sys
import pkg_resources
docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
version = pkg_resources.require('bob.learn.tensorflow')[0].version
args = docopt(docs, argv=argv, version=version)
config_files = args['<config_files>']
config = read_config_file(config_files)
# Sets-up logging
verbosity = getattr(config, 'verbose', 0)
set_verbosity_level(logger, verbosity)
database = config.database
data_dir, output_dir = config.data_dir, config.output_dir
file_to_label = config.file_to_label
reader = getattr(config, 'reader', load)
groups = getattr(config, 'groups', ['world'])
data_extension = getattr(config, 'data_extension', '.hdf5')
shuffle = getattr(config, 'shuffle', False)
data_type = getattr(config, 'data_type', "float32")
create_directories_safe(output_dir)
if not isinstance(groups, (list, tuple)):
groups = [groups]
for group in groups:
output_file = os.path.join(output_dir, '{}.tfrecords'.format(group))
files = database.all_files(groups=group)
if shuffle:
random.shuffle(files)
n_files = len(files)
with tf.python_io.TFRecordWriter(output_file) as writer:
for i, f in enumerate(files):
logger.info('Processing file %d out of %d', i + 1, n_files)
path = f.make_path(data_dir, data_extension)
img = bob2skimage(reader(path)).astype(data_type)
img = img.reshape((list(img.shape) + [1]))
data = img.tostring()
feature = {'train/data': _bytes_feature(data),
'train/label': _int64_feature(file_to_label(f))}
example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example.SerializeToString())
from docopt import docopt
import os
import sys
import pkg_resources
docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
version = pkg_resources.require('bob.learn.tensorflow')[0].version
args = docopt(docs, argv=argv, version=version)
config_files = args['<config_files>']
config = read_config_file(config_files)
# Sets-up logging
verbosity = getattr(config, 'verbose', 0)
set_verbosity_level(logger, verbosity)
database = config.database
data_dir, output_dir = config.data_dir, config.output_dir
file_to_label = config.file_to_label
reader = getattr(config, 'reader', load)
groups = getattr(config, 'groups', ['world'])
data_extension = getattr(config, 'data_extension', '.hdf5')
shuffle = getattr(config, 'shuffle', False)
one_file_one_sample = getattr(config, 'one_file_one_sample', True)
create_directories_safe(output_dir)
if not isinstance(groups, (list, tuple)):
groups = [groups]
for group in groups:
output_file = os.path.join(output_dir, '{}.tfrecords'.format(group))
files = database.all_files(groups=group)
if shuffle:
random.shuffle(files)
n_files = len(files)
with tf.python_io.TFRecordWriter(output_file) as writer:
for i, f in enumerate(files):
logger.info('Processing file %d out of %d', i + 1, n_files)
path = f.make_path(data_dir, data_extension)
data = reader(path)
label = file_to_label(f)
if one_file_one_sample:
write_a_sample(writer, data, label)
else:
for sample in data:
write_a_sample(writer, sample, label)
if __name__ == '__main__':
    main()
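
As a usage note, here is a hedged sketch of a config file exercising the new option; database, file_to_label, and the paths are placeholders rather than values taken from this commit:

# tfrecords_config.py -- hypothetical config for this script
from bob.bio.base.utils import load

database = ...                          # a bob database providing all_files()
data_dir = '/path/to/preprocessed'
output_dir = '/path/to/tfrecords'
file_to_label = lambda f: f.client_id   # placeholder label mapping
reader = load                           # default reader
groups = ['world']
shuffle = True
# each file holds several samples; write them as independent records
one_file_one_sample = False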