Commit ddbb1251 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Removing LFW tfrecord generator

parent d12b00da
Pipeline #13594 failed with stages
in 19 minutes and 46 seconds
#!/usr/bin/env python
"""Script that converts bob.db.lfw database to TF records
%(prog)s <data-path> <output-file> [--extension=<arg> --protocol=<arg> --data-type=<arg> --verbose]
%(prog)s --help
%(prog)s --version
-h --help show this help message and exit
<data-path> Path that contains the features
--extension=<arg> Default feature extension [default: .hdf5]
--protocol=<arg> One of the LFW protocols [default: view1]
--data-type=<arg> TFRecord data type [default: uint8]
The possible protocol options are the following:
'view1', 'fold1', 'fold2', 'fold3', 'fold4', 'fold5', 'fold6', 'fold7', 'fold8', 'fold9', 'fold10'
More details about our interface to LFW database can be found in
import tensorflow as tf
from import create_directories_safe
from import load, read_config_file
from bob.core.log import setup, set_verbosity_level
import bob.db.lfw
import os
import numpy
logger = setup(__name__)
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def file_to_label(client_ids, f):
return client_ids[str(f.client_id)]
def get_pairs(all_pairs, match=True):
enroll = []
probe = []
for p in all_pairs:
if p.is_match == match:
return enroll, probe
def main(argv=None):
from docopt import docopt
args = docopt(__doc__, version='')
data_path = args['<data-path>']
output_file = args['<output-file>']
extension = args['--extension']
protocol = args['--protocol']
data_type = args['--data-type']
# Sets-up logging
if args['--verbose']:
verbosity = 2
set_verbosity_level(logger, verbosity)
# Loading LFW models
database = bob.db.lfw.Database()
enroll, probe = get_pairs(database.pairs(protocol=protocol), match=True)
#client_ids = list(set([f.client_id for f in all_pairs]))
client_ids = list(set([f.client_id for f in enroll] + [f.client_id for f in probe]))
client_ids = dict(zip(client_ids, range(len(client_ids))))
n_files = len(enroll)
with tf.python_io.TFRecordWriter(output_file) as writer:
for e, p, i in zip(enroll, probe, range(len(enroll)) ):'Processing pair %d out of %d', i + 1, n_files)
if os.path.exists(e.make_path(data_path, extension)) and os.path.exists(p.make_path(data_path, extension)):
for f in [e, p]:
path = f.make_path(data_path, extension)
data =
data = data.tostring()
feature = {'data': _bytes_feature(data),
'label': _int64_feature(file_to_label(client_ids, f)),
'key': _bytes_feature(str(f.path)),
example = tf.train.Example(features=tf.train.Features(feature=feature))
logger.debug("... Processing original data file '{0}' was not successful".format(path))
if __name__ == '__main__':
......@@ -50,7 +50,6 @@ setup(
' = bob.learn.tensorflow.script.compute_statistics:main',
'bob_tf_db_to_tfrecords = bob.learn.tensorflow.script.db_to_tfrecords:main',
' = bob.learn.tensorflow.script.load_and_debug:main',
' = bob.learn.tensorflow.script.lfw_db_to_tfrecords:main',
'bob_tf_train_generic = bob.learn.tensorflow.script.train_generic:main',
'bob_tf_eval_generic = bob.learn.tensorflow.script.eval_generic:main',
'bob_tf_predict_generic = bob.learn.tensorflow.script.predict_generic:main',
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment