Commit 77c46255 authored by Amir MOHAMMADI

Add support for databases which contain more than one sample in one file

parents 578f27a2 2f2a62c9
@@ -62,6 +62,10 @@ The config files should have the following objects totally:
 reader = Extractor().read_feature
 # or
 from bob.bio.base.utils import load as reader
+# or a reader that casts images to uint8:
+def reader(path):
+    data = bob.bio.base.utils.load(path)
+    return data.astype("uint8")
 # extension of the preprocessed files. [default: '.hdf5']
 data_extension = '.hdf5'
@@ -69,6 +73,10 @@ The config files should have the following objects totally:
 # Shuffle the files before writing them into a tfrecords. [default: False]
 shuffle = True
+# Whether each file contains one sample or more. [default: True] If this is
+# False, the samples loaded from a file are iterated over and each of them
+# is saved as an independent feature.
+one_file_one_sample = True
 """
 from __future__ import absolute_import
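
Note (not part of the commit): the one_file_one_sample switch documented above only matters when the reader returns a container of samples. A hypothetical config fragment for that case, assuming each preprocessed file stores its samples stacked along the first axis of a single array (the stacking convention is an assumption, not something this commit prescribes):

import bob.bio.base.utils

# Hypothetical reader for files that hold several samples; iterating over
# the returned array yields one sample at a time, which is what the
# one_file_one_sample = False branch of the script relies on.
def reader(path):
    data = bob.bio.base.utils.load(path)  # e.g. shape (n_samples, h, w)
    return data.astype("float32")

# Iterate over the loaded array and write each entry as an independent
# feature instead of writing the whole array as one sample.
one_file_one_sample = False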
@@ -81,7 +89,6 @@ from bob.io.base import create_directories_safe
 from bob.bio.base.utils import load, read_config_file
 from bob.core.log import setup, set_verbosity_level
 logger = setup(__name__)
-import numpy

 def _bytes_feature(value):
@@ -92,73 +99,63 @@ def _int64_feature(value):
     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

-def bob2skimage(bob_image):
-    """
-    Convert a bob color image to a scikit-image image
-    """
-    if bob_image.ndim == 2:
-        skimage = numpy.zeros(shape=(bob_image.shape[0], bob_image.shape[1], 1))
-        skimage[:, :, 0] = bob_image
-    else:
-        skimage = numpy.zeros(shape=(bob_image.shape[1], bob_image.shape[2], bob_image.shape[0]))
-        skimage[:, :, 2] = bob_image[0, :, :]
-        skimage[:, :, 1] = bob_image[1, :, :]
-        skimage[:, :, 0] = bob_image[2, :, :]
-    return skimage
+def write_a_sample(writer, data, label):
+    feature = {'train/data': _bytes_feature(data.tostring()),
+               'train/label': _int64_feature(label)}
+    example = tf.train.Example(features=tf.train.Features(feature=feature))
+    writer.write(example.SerializeToString())

 def main(argv=None):
     from docopt import docopt
     import os
     import sys
     import pkg_resources
     docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
     version = pkg_resources.require('bob.learn.tensorflow')[0].version
     args = docopt(docs, argv=argv, version=version)
     config_files = args['<config_files>']
     config = read_config_file(config_files)

     # Sets up logging
     verbosity = getattr(config, 'verbose', 0)
     set_verbosity_level(logger, verbosity)

     database = config.database
     data_dir, output_dir = config.data_dir, config.output_dir
     file_to_label = config.file_to_label
     reader = getattr(config, 'reader', load)
     groups = getattr(config, 'groups', ['world'])
     data_extension = getattr(config, 'data_extension', '.hdf5')
     shuffle = getattr(config, 'shuffle', False)
-    data_type = getattr(config, 'data_type', "float32")
+    one_file_one_sample = getattr(config, 'one_file_one_sample', True)

     create_directories_safe(output_dir)
     if not isinstance(groups, (list, tuple)):
         groups = [groups]

     for group in groups:
         output_file = os.path.join(output_dir, '{}.tfrecords'.format(group))
         files = database.all_files(groups=group)
         if shuffle:
             random.shuffle(files)
         n_files = len(files)
         with tf.python_io.TFRecordWriter(output_file) as writer:
             for i, f in enumerate(files):
                 logger.info('Processing file %d out of %d', i + 1, n_files)
                 path = f.make_path(data_dir, data_extension)
-                img = bob2skimage(reader(path)).astype(data_type)
-                img = img.reshape((list(img.shape) + [1]))
-                data = img.tostring()
-                feature = {'train/data': _bytes_feature(data),
-                           'train/label': _int64_feature(file_to_label(f))}
-                example = tf.train.Example(features=tf.train.Features(feature=feature))
-                writer.write(example.SerializeToString())
+                data = reader(path)
+                label = file_to_label(f)
+                if one_file_one_sample:
+                    write_a_sample(writer, data, label)
+                else:
+                    for sample in data:
+                        write_a_sample(writer, sample, label)

 if __name__ == '__main__':
     main()
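
Usage note (not part of the commit): each record written by write_a_sample is a serialized tf.train.Example with two keys, 'train/data' (the raw sample bytes) and 'train/label' (an int64). A minimal sketch for reading one record back with the same TF 1.x API family the script uses, assuming the output file is world.tfrecords and the samples were written as float32 (both assumptions):

import numpy
import tensorflow as tf

for serialized in tf.python_io.tf_record_iterator('world.tfrecords'):
    example = tf.train.Example()
    example.ParseFromString(serialized)
    raw = example.features.feature['train/data'].bytes_list.value[0]
    # The shape is not stored in the record, so the array comes back 1-D;
    # the dtype must match whatever the configured reader produced.
    data = numpy.fromstring(raw, dtype='float32')
    label = example.features.feature['train/label'].int64_list.value[0]
    print(data.size, label)
    break  # inspect only the first record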