Skip to content
Snippets Groups Projects
Commit ddbb1251 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Removing LFW tfrecord generator

parent d12b00da
No related branches found
No related tags found
1 merge request!29Created a function that batches a tf-record in order and apply data augmentation
Pipeline #
This commit is part of merge request !29. Comments created here will be created in the context of that merge request.
#!/usr/bin/env python
"""Script that converts bob.db.lfw database to TF records
Usage:
%(prog)s <data-path> <output-file> [--extension=<arg> --protocol=<arg> --data-type=<arg> --verbose]
%(prog)s --help
%(prog)s --version
Options:
-h --help show this help message and exit
<data-path> Path that contains the features
--extension=<arg> Default feature extension [default: .hdf5]
--protocol=<arg> One of the LFW protocols [default: view1]
--data-type=<arg> TFRecord data type [default: uint8]
The possible protocol options are the following:
'view1', 'fold1', 'fold2', 'fold3', 'fold4', 'fold5', 'fold6', 'fold7', 'fold8', 'fold9', 'fold10'
More details about our interface to LFW database can be found in
https://www.idiap.ch/software/bob/docs/bob/bob.db.lfw/master/index.html.
"""
import tensorflow as tf
from bob.io.base import create_directories_safe
from bob.bio.base.utils import load, read_config_file
from bob.core.log import setup, set_verbosity_level
import bob.db.lfw
import os
import bob.io.image
import bob.io.base
import numpy
logger = setup(__name__)
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def file_to_label(client_ids, f):
return client_ids[str(f.client_id)]
def get_pairs(all_pairs, match=True):
enroll = []
probe = []
for p in all_pairs:
if p.is_match == match:
enroll.append(p.enroll_file)
probe.append(p.probe_file)
return enroll, probe
def main(argv=None):
from docopt import docopt
args = docopt(__doc__, version='')
data_path = args['<data-path>']
output_file = args['<output-file>']
extension = args['--extension']
protocol = args['--protocol']
data_type = args['--data-type']
# Sets-up logging
if args['--verbose']:
verbosity = 2
set_verbosity_level(logger, verbosity)
# Loading LFW models
database = bob.db.lfw.Database()
enroll, probe = get_pairs(database.pairs(protocol=protocol), match=True)
#client_ids = list(set([f.client_id for f in all_pairs]))
client_ids = list(set([f.client_id for f in enroll] + [f.client_id for f in probe]))
client_ids = dict(zip(client_ids, range(len(client_ids))))
create_directories_safe(os.path.dirname(output_file))
n_files = len(enroll)
with tf.python_io.TFRecordWriter(output_file) as writer:
for e, p, i in zip(enroll, probe, range(len(enroll)) ):
logger.info('Processing pair %d out of %d', i + 1, n_files)
if os.path.exists(e.make_path(data_path, extension)) and os.path.exists(p.make_path(data_path, extension)):
for f in [e, p]:
path = f.make_path(data_path, extension)
data = bob.io.image.to_matplotlib(bob.io.base.load(path)).astype(data_type)
data = data.tostring()
feature = {'data': _bytes_feature(data),
'label': _int64_feature(file_to_label(client_ids, f)),
'key': _bytes_feature(str(f.path)),
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example.SerializeToString())
else:
logger.debug("... Processing original data file '{0}' was not successful".format(path))
if __name__ == '__main__':
main()
......@@ -50,7 +50,6 @@ setup(
'bob_tf_compute_statistics.py = bob.learn.tensorflow.script.compute_statistics:main',
'bob_tf_db_to_tfrecords = bob.learn.tensorflow.script.db_to_tfrecords:main',
'bob_tf_load_and_debug.py = bob.learn.tensorflow.script.load_and_debug:main',
'bob_tf_lfw_db_to_tfrecords.py = bob.learn.tensorflow.script.lfw_db_to_tfrecords:main',
'bob_tf_train_generic = bob.learn.tensorflow.script.train_generic:main',
'bob_tf_eval_generic = bob.learn.tensorflow.script.eval_generic:main',
'bob_tf_predict_generic = bob.learn.tensorflow.script.predict_generic:main',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment