Commit 77c46255 authored by Amir MOHAMMADI

Add support for databases which contain more than one sample in one file

parents 578f27a2 2f2a62c9
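
For illustration, here is a minimal sketch of a config file that exercises the new option. The database object, paths, and file_to_label logic are hypothetical placeholders; only the variable names come from the docstring changed below.

# config.py -- hypothetical sketch, not part of this commit
from bob.bio.base.utils import load

database = ...  # a bob database instance providing all_files(groups=...)
data_dir = '/path/to/preprocessed'  # hypothetical path
output_dir = '/path/to/tfrecords'   # hypothetical path

def file_to_label(f):
    # hypothetical mapping from a database File object to an integer label
    return int(f.client_id)

def reader(path):
    # each file here holds an array of several samples; return them all
    return load(path)

# the new option: with False, the loaded array is iterated over and every
# element is written to the tfrecords file as its own sample
one_file_one_sample = False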
@@ -62,6 +62,10 @@ The config files should have the following objects totally:
     reader = Extractor().read_feature
     # or
     from bob.bio.base.utils import load as reader
+    # or a reader that casts images to uint8:
+    def reader(path):
+        data = bob.bio.base.utils.load(path)
+        return data.astype("uint8")
 
     # extension of the preprocessed files. [default: '.hdf5']
     data_extension = '.hdf5'
@@ -69,6 +73,10 @@ The config files should have the following objects totally:
 
     # Shuffle the files before writing them into a tfrecords. [default: False]
     shuffle = True
+    # Whether each file contains one sample or more. [default: True] If this
+    # is False, the loaded samples from a file are iterated over and each of
+    # them is saved as an independent feature.
+    one_file_one_sample = True
 """
 from __future__ import absolute_import
@@ -81,7 +89,6 @@ from bob.io.base import create_directories_safe
 from bob.bio.base.utils import load, read_config_file
 from bob.core.log import setup, set_verbosity_level
 logger = setup(__name__)
-import numpy
 
 
 def _bytes_feature(value):
@@ -92,73 +99,63 @@ def _int64_feature(value):
     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 
 
-def bob2skimage(bob_image):
-    """
-    Convert bob color image to the skcit image
-    """
-
-    if bob_image.ndim==2:
-        skimage = numpy.zeros(shape=(bob_image.shape[0], bob_image.shape[1], 1))
-        skimage[:, :, 0] = bob_image
-    else:
-        skimage = numpy.zeros(shape=(bob_image.shape[1], bob_image.shape[2], bob_image.shape[0]))
-        skimage[:, :, 2] = bob_image[0, :, :]
-        skimage[:, :, 1] = bob_image[1, :, :]
-        skimage[:, :, 0] = bob_image[2, :, :]
-    return skimage
+def write_a_sample(writer, data, label):
+    feature = {'train/data': _bytes_feature(data.tostring()),
+               'train/label': _int64_feature(label)}
+
+    example = tf.train.Example(features=tf.train.Features(feature=feature))
+    writer.write(example.SerializeToString())
 
 
 def main(argv=None):
     from docopt import docopt
     import os
     import sys
     import pkg_resources
     docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
     version = pkg_resources.require('bob.learn.tensorflow')[0].version
     args = docopt(docs, argv=argv, version=version)
     config_files = args['<config_files>']
     config = read_config_file(config_files)
 
     # Sets-up logging
     verbosity = getattr(config, 'verbose', 0)
     set_verbosity_level(logger, verbosity)
 
     database = config.database
     data_dir, output_dir = config.data_dir, config.output_dir
     file_to_label = config.file_to_label
     reader = getattr(config, 'reader', load)
     groups = getattr(config, 'groups', ['world'])
     data_extension = getattr(config, 'data_extension', '.hdf5')
     shuffle = getattr(config, 'shuffle', False)
-    data_type = getattr(config, 'data_type', "float32")
+    one_file_one_sample = getattr(config, 'one_file_one_sample', True)
 
     create_directories_safe(output_dir)
     if not isinstance(groups, (list, tuple)):
         groups = [groups]
 
     for group in groups:
         output_file = os.path.join(output_dir, '{}.tfrecords'.format(group))
         files = database.all_files(groups=group)
         if shuffle:
             random.shuffle(files)
         n_files = len(files)
         with tf.python_io.TFRecordWriter(output_file) as writer:
             for i, f in enumerate(files):
                 logger.info('Processing file %d out of %d', i + 1, n_files)
-
                 path = f.make_path(data_dir, data_extension)
-                img = bob2skimage(reader(path)).astype(data_type)
-                img = img.reshape((list(img.shape) + [1]))
-                data = img.tostring()
-
-                feature = {'train/data': _bytes_feature(data),
-                           'train/label': _int64_feature(file_to_label(f))}
-
-                example = tf.train.Example(features=tf.train.Features(feature=feature))
-                writer.write(example.SerializeToString())
+                data = reader(path)
+                label = file_to_label(f)
+
+                if one_file_one_sample:
+                    write_a_sample(writer, data, label)
+                else:
+                    for sample in data:
+                        write_a_sample(writer, sample, label)
 
 
 if __name__ == '__main__':
     main()
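
Note that write_a_sample stores only the raw bytes of the array (tostring() keeps neither shape nor dtype), so a consumer has to know both in advance. Below is a minimal sketch of reading such a file back with the same TF 1.x API generation used above; the sample shape, dtype, and path are assumptions, not part of this commit.

import numpy as np
import tensorflow as tf

# hypothetical: shape and dtype must match what the configured reader
# produced, since tostring() stored raw bytes without any metadata
SAMPLE_SHAPE = (112, 112, 3)
SAMPLE_DTYPE = 'uint8'

for record in tf.python_io.tf_record_iterator('tfrecords/world.tfrecords'):
    # parse the serialized tf.train.Example proto written by write_a_sample
    example = tf.train.Example()
    example.ParseFromString(record)
    raw = example.features.feature['train/data'].bytes_list.value[0]
    label = example.features.feature['train/label'].int64_list.value[0]
    sample = np.frombuffer(raw, dtype=SAMPLE_DTYPE).reshape(SAMPLE_SHAPE)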