Commit 2404d630 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch 'batch-tfrecord' into 'master'

Created a function that batches a tf-record in order and apply data augmentation

See merge request !29
parents b9c51932 ddbb1251
Pipeline #13597 failed with stages
in 18 minutes and 35 seconds
......@@ -291,4 +291,55 @@ def batch_data_and_labels(tfrecord_filenames, data_shape, data_type,
features['key'] = key
return features, labels
def batch_data_and_labels_image_augmentation(tfrecord_filenames, data_shape, data_type,
batch_size, epochs=1,
gray_scale=False,
output_shape=None,
random_flip=False,
random_brightness=False,
random_contrast=False,
random_saturation=False,
per_image_normalization=True):
"""
Dump in order batches from a list of tf-record files
**Parameters**
tfrecord_filenames:
List containing the tf-record paths
data_shape:
Samples shape saved in the tf-record
data_type:
tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)
batch_size:
Size of the batch
epochs:
Number of epochs to be batched
"""
dataset = create_dataset_from_records_with_augmentation(tfrecord_filenames, data_shape,
data_type,
gray_scale=gray_scale,
output_shape=output_shape,
random_flip=random_flip,
random_brightness=random_brightness,
random_contrast=random_contrast,
random_saturation=random_saturation,
per_image_normalization=per_image_normalization)
dataset = dataset.batch(batch_size).repeat(epochs)
data, labels, key = dataset.make_one_shot_iterator().get_next()
features = dict()
features['data'] = data
features['key'] = key
return features, labels
#!/usr/bin/env python
"""Script that converts bob.db.lfw database to TF records
Usage:
%(prog)s <data-path> <output-file> [--extension=<arg> --protocol=<arg> --data-type=<arg> --verbose]
%(prog)s --help
%(prog)s --version
Options:
-h --help show this help message and exit
<data-path> Path that contains the features
--extension=<arg> Default feature extension [default: .hdf5]
--protocol=<arg> One of the LFW protocols [default: view1]
--data-type=<arg> TFRecord data type [default: uint8]
The possible protocol options are the following:
'view1', 'fold1', 'fold2', 'fold3', 'fold4', 'fold5', 'fold6', 'fold7', 'fold8', 'fold9', 'fold10'
More details about our interface to LFW database can be found in
https://www.idiap.ch/software/bob/docs/bob/bob.db.lfw/master/index.html.
"""
import tensorflow as tf
from bob.io.base import create_directories_safe
from bob.bio.base.utils import load, read_config_file
from bob.core.log import setup, set_verbosity_level
import bob.db.lfw
import os
import bob.io.image
import bob.io.base
import numpy
logger = setup(__name__)
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def file_to_label(client_ids, f):
return client_ids[str(f.client_id)]
def get_pairs(all_pairs, match=True):
enroll = []
probe = []
for p in all_pairs:
if p.is_match == match:
enroll.append(p.enroll_file)
probe.append(p.probe_file)
return enroll, probe
def main(argv=None):
from docopt import docopt
args = docopt(__doc__, version='')
data_path = args['<data-path>']
output_file = args['<output-file>']
extension = args['--extension']
protocol = args['--protocol']
data_type = args['--data-type']
# Sets-up logging
if args['--verbose']:
verbosity = 2
set_verbosity_level(logger, verbosity)
# Loading LFW models
database = bob.db.lfw.Database()
enroll, probe = get_pairs(database.pairs(protocol=protocol), match=True)
#client_ids = list(set([f.client_id for f in all_pairs]))
client_ids = list(set([f.client_id for f in enroll] + [f.client_id for f in probe]))
client_ids = dict(zip(client_ids, range(len(client_ids))))
create_directories_safe(os.path.dirname(output_file))
n_files = len(enroll)
with tf.python_io.TFRecordWriter(output_file) as writer:
for e, p, i in zip(enroll, probe, range(len(enroll)) ):
logger.info('Processing pair %d out of %d', i + 1, n_files)
if os.path.exists(e.make_path(data_path, extension)) and os.path.exists(p.make_path(data_path, extension)):
for f in [e, p]:
path = f.make_path(data_path, extension)
data = bob.io.image.to_matplotlib(bob.io.base.load(path)).astype(data_type)
data = data.tostring()
feature = {'train/data': _bytes_feature(data),
'train/label': _int64_feature(file_to_label(client_ids, f))}
example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example.SerializeToString())
else:
logger.debug("... Processing original data file '{0}' was not successful".format(path))
if __name__ == '__main__':
main()
......@@ -50,7 +50,6 @@ setup(
'bob_tf_compute_statistics.py = bob.learn.tensorflow.script.compute_statistics:main',
'bob_tf_db_to_tfrecords = bob.learn.tensorflow.script.db_to_tfrecords:main',
'bob_tf_load_and_debug.py = bob.learn.tensorflow.script.load_and_debug:main',
'bob_tf_lfw_db_to_tfrecords.py = bob.learn.tensorflow.script.lfw_db_to_tfrecords:main',
'bob_tf_train_generic = bob.learn.tensorflow.script.train_generic:main',
'bob_tf_eval_generic = bob.learn.tensorflow.script.eval_generic:main',
'bob_tf_predict_generic = bob.learn.tensorflow.script.predict_generic:main',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment