def batch_data_and_labels_image_augmentation(tfrecord_filenames, data_shape, data_type,
                                             batch_size, epochs=1,
                                             gray_scale=False,
                                             output_shape=None,
                                             random_flip=False,
                                             random_brightness=False,
                                             random_contrast=False,
                                             random_saturation=False,
                                             per_image_normalization=True):
    """
    Dump, in order, batches from a list of tf-record files. The dataset is
    built by ``create_dataset_from_records_with_augmentation``, which receives
    all the augmentation options below unchanged.

    **Parameters**

       tfrecord_filenames:
          List containing the tf-record paths

       data_shape:
          Samples shape saved in the tf-record

       data_type:
          tf data type(https://www.tensorflow.org/versions/r0.12/resources/dims_types#data_types)

       batch_size:
          Size of the batch

       epochs:
          Number of epochs to be batched

       gray_scale, output_shape, random_flip, random_brightness,
       random_contrast, random_saturation, per_image_normalization:
          Forwarded as keyword arguments to
          ``create_dataset_from_records_with_augmentation``; see that function
          for their exact meaning.

    **Returns**

       A tuple ``(features, labels)`` where ``features`` is a dictionary with
       the entries ``'data'`` and ``'key'``.

    """

    # Everything related to augmentation is simply passed through.
    augmentation_options = {
        'gray_scale': gray_scale,
        'output_shape': output_shape,
        'random_flip': random_flip,
        'random_brightness': random_brightness,
        'random_contrast': random_contrast,
        'random_saturation': random_saturation,
        'per_image_normalization': per_image_normalization,
    }
    records = create_dataset_from_records_with_augmentation(
        tfrecord_filenames, data_shape, data_type, **augmentation_options)

    # Batch first, then repeat the batched dataset for the requested epochs.
    batched = records.batch(batch_size)
    repeated = batched.repeat(epochs)

    iterator = repeated.make_one_shot_iterator()
    data, labels, key = iterator.get_next()

    features = {'data': data, 'key': key}
    return features, labels
#!/usr/bin/env python

"""Script that converts bob.db.lfw database to TF records

Usage:
  %(prog)s <data-path> <output-file> [--extension=<arg> --protocol=<arg> --data-type=<arg> --verbose]
  %(prog)s --help
  %(prog)s --version

Options:
  -h --help            show this help message and exit
  <data-path>          Path that contains the features
  --extension=<arg>    Default feature extension  [default: .hdf5]
  --protocol=<arg>     One of the LFW protocols  [default: view1]
  --data-type=<arg>    TFRecord data type [default: uint8]


The possible protocol options are the following:
  'view1', 'fold1', 'fold2', 'fold3', 'fold4', 'fold5', 'fold6', 'fold7', 'fold8', 'fold9', 'fold10'

More details about our interface to LFW database can be found in
https://www.idiap.ch/software/bob/docs/bob/bob.db.lfw/master/index.html.


"""

import tensorflow as tf
from bob.io.base import create_directories_safe
from bob.bio.base.utils import load, read_config_file
from bob.core.log import setup, set_verbosity_level
import bob.db.lfw
import os
import bob.io.image
import bob.io.base
import numpy

logger = setup(__name__)


def _bytes_feature(value):
    """Wrap a single bytes ``value`` in a ``tf.train.Feature``."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_feature(value):
    """Wrap a single integer ``value`` in a ``tf.train.Feature``."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def file_to_label(client_ids, f):
    """Return the label stored in ``client_ids`` for ``str(f.client_id)``."""
    return client_ids[str(f.client_id)]


def get_pairs(all_pairs, match=True):
    """Split the pairs whose ``is_match`` equals ``match`` into two lists.

    Returns ``(enroll, probe)``: the ``enroll_file`` and ``probe_file``
    attributes of the selected pairs, in the order the pairs were given.
    """
    selected = [p for p in all_pairs if p.is_match == match]
    enroll = [p.enroll_file for p in selected]
    probe = [p.probe_file for p in selected]
    return enroll, probe


def main(argv=None):
    """Write the matching LFW pairs of one protocol to a TFRecord file.

    ``argv`` is the list of command line arguments handed to docopt; when it
    is ``None`` docopt falls back to ``sys.argv[1:]``.

    Every file of every matching pair is loaded from ``<data-path>``, cast to
    ``--data-type`` and written as one ``tf.train.Example`` with the features
    ``'train/data'`` and ``'train/label'``. A pair is skipped entirely when
    at least one of its two files does not exist.
    """
    from docopt import docopt
    args = docopt(__doc__, argv=argv, version='')

    data_path = args['<data-path>']
    output_file = args['<output-file>']
    extension = args['--extension']
    protocol = args['--protocol']
    data_type = args['--data-type']

    # Sets-up logging
    if args['--verbose']:
        set_verbosity_level(logger, 2)

    # Loading LFW models
    database = bob.db.lfw.Database()
    enroll, probe = get_pairs(database.pairs(protocol=protocol), match=True)

    # Keys are stored as ``str`` because ``file_to_label`` looks them up with
    # ``str(f.client_id)``; sorting makes the label assignment reproducible
    # instead of depending on set iteration order.
    unique_ids = sorted({str(f.client_id) for f in enroll + probe})
    client_ids = {client_id: label for label, client_id in enumerate(unique_ids)}

    # ``os.path.dirname`` returns '' for a bare file name; there is nothing
    # to create in that case.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        create_directories_safe(output_dir)

    n_files = len(enroll)
    with tf.python_io.TFRecordWriter(output_file) as writer:
        for i, (e, p) in enumerate(zip(enroll, probe)):
            logger.info('Processing pair %d out of %d', i + 1, n_files)

            # Resolve both paths up-front so that the "missing file" message
            # can always name the file that is actually absent.
            pair = (e, p)
            paths = [f.make_path(data_path, extension) for f in pair]
            missing = [path for path in paths if not os.path.exists(path)]
            if missing:
                for path in missing:
                    logger.debug("... Processing original data file '%s' was not successful", path)
                continue

            for f, path in zip(pair, paths):
                data = bob.io.image.to_matplotlib(bob.io.base.load(path)).astype(data_type)

                feature = {'train/data': _bytes_feature(data.tobytes()),
                           'train/label': _int64_feature(file_to_label(client_ids, f))}

                example = tf.train.Example(features=tf.train.Features(feature=feature))
                writer.write(example.SerializeToString())


if __name__ == '__main__':
    main()