diff --git a/bob/learn/tensorflow/script/lfw_db_to_tfrecords.py b/bob/learn/tensorflow/script/lfw_db_to_tfrecords.py index 7999b635731baffc9800ead015581221667a31e3..91dc27b1f0a7f642a1acc0d314045118f767edd9 100755 --- a/bob/learn/tensorflow/script/lfw_db_to_tfrecords.py +++ b/bob/learn/tensorflow/script/lfw_db_to_tfrecords.py @@ -95,8 +95,10 @@ def main(argv=None): data = bob.io.image.to_matplotlib(bob.io.base.load(path)).astype(data_type) data = data.tostring() - feature = {'train/data': _bytes_feature(data), - 'train/label': _int64_feature(file_to_label(client_ids, f))} + feature = {'data': _bytes_feature(data), + 'label': _int64_feature(file_to_label(client_ids, f)), + 'key': _bytes_feature(str(f.path)), + } example = tf.train.Example(features=tf.train.Features(feature=feature)) writer.write(example.SerializeToString())