Add support for databases which contain more than one sample in one file

Closed Amir MOHAMMADI requested to merge mutilsample into master
1 unresolved thread
1 file  +62 −65
@@ -69,6 +69,13 @@ The config files should have the following objects totally:
 # Shuffle the files before writing them into a tfrecords. [default: False]
 shuffle = True
 
+# Whether each file contains one sample or more. [default: True] If
+# this is False, the loaded samples from a file are iterated over and each
+# of them is saved as an independent feature.
+one_file_one_sample = True
+
+# Converts the read data to this format. [default: float32]
+data_type = "float32"
 """
 from __future__ import absolute_import
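
For reference, a config file exercising the new options could look like the sketch below. The database object and file_to_label are illustrative placeholders; only one_file_one_sample and data_type are introduced by this merge request:

    # config.py -- hypothetical example, not part of this merge request
    from bob.bio.base.utils import load

    database = ...  # any bob database object exposing all_files(groups=...)
    data_dir = '/path/to/input'
    output_dir = '/path/to/tfrecords'

    def file_to_label(f):
        # placeholder: map a database File object to an integer label
        return int(f.client_id)

    reader = load               # the default reader
    groups = ['world']
    shuffle = True
    one_file_one_sample = True  # new: set to False for multi-sample files
    data_type = "float32"       # new: cast loaded data to this dtype
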
@@ -81,7 +88,6 @@ from bob.io.base import create_directories_safe
 from bob.bio.base.utils import load, read_config_file
 from bob.core.log import setup, set_verbosity_level
 logger = setup(__name__)
-import numpy
 
 
 def _bytes_feature(value):
@@ -92,73 +98,64 @@ def _int64_feature(value):
     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 
 
-def bob2skimage(bob_image):
-    """
-    Convert bob color image to the skcit image
-    """
-    if bob_image.ndim==2:
-        skimage = numpy.zeros(shape=(bob_image.shape[0], bob_image.shape[1], 1))
-        skimage[:, :, 0] = bob_image
-    else:
-        skimage = numpy.zeros(shape=(bob_image.shape[1], bob_image.shape[2], bob_image.shape[0]))
-        skimage[:, :, 2] = bob_image[0, :, :]
-        skimage[:, :, 1] = bob_image[1, :, :]
-        skimage[:, :, 0] = bob_image[2, :, :]
-    return skimage
+def write_a_sample(writer, data, label):
+    feature = {'train/data': _bytes_feature(data.tostring()),
+               'train/label': _int64_feature(label)}
+    example = tf.train.Example(features=tf.train.Features(feature=feature))
+    writer.write(example.SerializeToString())
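
The new write_a_sample helper stores each sample as a flat byte string via tostring(), so the reading side must know the original shape and dtype. A minimal decoding sketch using the same TF 1.x API generation as the writer (the shape argument is an assumption for illustration; the dtype must match the data_type config option):

    import tensorflow as tf

    def read_a_sample(serialized, shape, dtype=tf.float32):
        # Mirror of write_a_sample: parse the two features it writes.
        features = {'train/data': tf.FixedLenFeature([], tf.string),
                    'train/label': tf.FixedLenFeature([], tf.int64)}
        parsed = tf.parse_single_example(serialized, features=features)
        # The data was written with numpy's tostring(), i.e. raw bytes.
        data = tf.decode_raw(parsed['train/data'], dtype)
        data = tf.reshape(data, shape)
        return data, parsed['train/label']
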
 def main(argv=None):
     from docopt import docopt
     import os
     import sys
     import pkg_resources
     docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
     version = pkg_resources.require('bob.learn.tensorflow')[0].version
     args = docopt(docs, argv=argv, version=version)
     config_files = args['<config_files>']
     config = read_config_file(config_files)
 
     # Sets-up logging
     verbosity = getattr(config, 'verbose', 0)
     set_verbosity_level(logger, verbosity)
 
     database = config.database
     data_dir, output_dir = config.data_dir, config.output_dir
     file_to_label = config.file_to_label
     reader = getattr(config, 'reader', load)
     groups = getattr(config, 'groups', ['world'])
     data_extension = getattr(config, 'data_extension', '.hdf5')
     shuffle = getattr(config, 'shuffle', False)
+    one_file_one_sample = getattr(config, 'one_file_one_sample', True)
     data_type = getattr(config, 'data_type', "float32")
 
     create_directories_safe(output_dir)
     if not isinstance(groups, (list, tuple)):
         groups = [groups]
 
     for group in groups:
         output_file = os.path.join(output_dir, '{}.tfrecords'.format(group))
         files = database.all_files(groups=group)
         if shuffle:
             random.shuffle(files)
         n_files = len(files)
         with tf.python_io.TFRecordWriter(output_file) as writer:
             for i, f in enumerate(files):
                 logger.info('Processing file %d out of %d', i + 1, n_files)
                 path = f.make_path(data_dir, data_extension)
-                img = bob2skimage(reader(path)).astype(data_type)
-                img = img.reshape((list(img.shape) + [1]))
-                data = img.tostring()
-                feature = {'train/data': _bytes_feature(data),
-                           'train/label': _int64_feature(file_to_label(f))}
-                example = tf.train.Example(features=tf.train.Features(feature=feature))
-                writer.write(example.SerializeToString())
+                data = reader(path).astype(data_type)
+                label = file_to_label(f)
+                if one_file_one_sample:
+                    write_a_sample(writer, data, label)
+                else:
+                    for sample in data:
+                        write_a_sample(writer, sample, label)
 
 
 if __name__ == '__main__':
     main()
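
Note that when one_file_one_sample is False, the `for sample in data` loop iterates over the first axis of whatever the reader returned, so a file loaded as an (n_samples, ...) array yields n_samples independent examples sharing the same label. A quick self-contained illustration:

    import numpy

    # Iterating a numpy array walks its first axis, so one "file" loaded
    # as a (3, 2, 2) array is written as three independent (2, 2) samples.
    data = numpy.arange(12, dtype='float32').reshape(3, 2, 2)
    for sample in data:
        print(sample.shape)  # prints (2, 2) three times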