diff --git a/bob/learn/tensorflow/script/db_to_tfrecords.py b/bob/learn/tensorflow/script/db_to_tfrecords.py
index f3b8404432ba987943d5d4fa3fd65fb8ea8328aa..aa5169935749e1022c84cca1a2e3427548a6689d 100644
--- a/bob/learn/tensorflow/script/db_to_tfrecords.py
+++ b/bob/learn/tensorflow/script/db_to_tfrecords.py
@@ -8,9 +8,9 @@ Usage:
     %(prog)s --version
 
 Arguments:
-  <config_files>  The config files. The config files are loaded in order and
-                  they need to have several objects inside totally. See below
-                  for explanation.
+  <config_files>  The configuration files to load. They are loaded in order
+                  and, together, they need to define several objects.
+                  See below for an explanation.
 
 Options:
   -h --help  show this help message and exit
@@ -21,7 +21,7 @@ Idiap:
 
     $ jman submit -i -q q1d -- bin/python %(prog)s <config_files>...
 
-The config files should have the following objects totally:
+Together, the configuration files should define the following objects:
 
 ## Required objects:
 
@@ -83,6 +83,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 import random
+# import pkg_resources so that bob imports work properly:
+import pkg_resources
 
 import tensorflow as tf
 from bob.io.base import create_directories_safe
@@ -112,7 +114,6 @@ def main(argv=None):
     from docopt import docopt
     import os
     import sys
-    import pkg_resources
     docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
     version = pkg_resources.require('bob.learn.tensorflow')[0].version
     args = docopt(docs, argv=argv, version=version)
diff --git a/bob/learn/tensorflow/script/eval_generic.py b/bob/learn/tensorflow/script/eval_generic.py
index 79ffce829cc60d2e3ca5460f9849db0fe1f458f2..2c36f55bf81c450dabb632ddc62d9039d57b44a7 100644
--- a/bob/learn/tensorflow/script/eval_generic.py
+++ b/bob/learn/tensorflow/script/eval_generic.py
@@ -1,66 +1,91 @@
 #!/usr/bin/env python
-"""Trains the VGG-audio network on the AVspoof database.
+"""Evaluates networks trained with tf.train.MonitoredTrainingSession
 
 Usage:
-  %(prog)s [options] <checkpoint_dir> <eval_tfrecords>...
+  %(prog)s [options] <config_files>...
+  %(prog)s --help
+  %(prog)s --version
+
+Arguments:
+  <config_files>  The configuration files to load. They are loaded in order
+                  and, together, they need to define several objects.
+                  See below for an explanation.
 
 Options:
-  -h --help                Show this screen.
-  --eval-dir PATH          [default: /idiap/user/amohammadi/avspoof/specgram/avspoof-simple-cnn-eval]
-  --input-shape N          [default: (50, 1024, 1)]
-  --batch-size N           [default: 50]
-  --run-once               Evaluate the model once only.
-  --eval-interval-secs N   Interval to evaluations. [default: 300]
-  --num-examples N         Number of examples to run. [default: None] Provide
-                           ``None`` to consider all examples.
+  -h --help  show this help message and exit
+  --version  show version and exit
+
+Together, the configuration files should define the following objects:
+
+  ## Required objects:
+
+  # checkpoint_dir
+  checkpoint_dir = 'train'
+  eval_dir = 'eval'
+  batch_size = 50
+  data, labels = get_data_and_labels()
+  logits = architecture(data)
+
+  ## Optional objects:
+
+  num_examples
+  run_once
+  eval_interval_secs
+
+Example configuration::
+
+    import tensorflow as tf
+
+    checkpoint_dir = 'avspoof-simple-cnn-train'
+    eval_dir = 'avspoof-simple-cnn-eval'
+    tfrecord_filenames = ['/path/to/dev.tfrecords']
+    data_shape = (50, 1024, 1)
+    data_type = tf.float32
+    batch_size = 50
+
+    from bob.learn.tensorflow.utils.tfrecords import batch_data_and_labels
+    def get_data_and_labels():
+        return batch_data_and_labels(tfrecord_filenames, data_shape, data_type,
+                                     batch_size)
+
+    from bob.pad.voice.architectures.simple_cnn import architecture
 """
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-# for bob imports to work properly:
+# import pkg_resources so that bob imports work properly:
 import pkg_resources
 import os
 import time
-from functools import partial
 
 import tensorflow as tf
-from docopt import docopt
-from bob.io.base import create_directories_safe
-from bob.dap.voice.architectures.simple_cnn import architecture
-from bob.dap.base.database.tfrecords import example_parser
-from bob.dap.base.util.eval import get_global_step, eval_once
+from bob.bio.base.utils import read_config_file
+from ..utils.eval import get_global_step, eval_once
 
 
 def main(argv=None):
-    arguments = docopt(__doc__, argv=argv)
-    print(arguments)
-    input_shape = eval(arguments['--input-shape'])
-    tfrecord_filenames = arguments['<eval_tfrecords>']
-    eval_dir = arguments['--eval-dir']
-    batch_size = eval(arguments['--batch-size'])
-    run_once = arguments['--run-once']
-    eval_interval_secs = eval(arguments['--eval-interval-secs'])
-    checkpoint_dir = arguments['<checkpoint_dir>']
-    num_examples = eval(arguments['--num-examples'])
-
-    create_directories_safe(eval_dir)
-    with tf.Graph().as_default() as g:
-
-        # Get images and labels
+    from docopt import docopt
+    import sys
+    docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
+    version = pkg_resources.require('bob.learn.tensorflow')[0].version
+    args = docopt(docs, argv=argv, version=version)
+    config_files = args['<config_files>']
+    config = read_config_file(config_files)
+
+    run_once = getattr(config, 'run_once', False)
+    eval_interval_secs = getattr(config, 'eval_interval_secs', 300)
+    num_examples = getattr(config, 'num_examples', None)
+
+    with tf.Graph().as_default() as graph:
+
+        # Get data and labels
         with tf.name_scope('input'):
-            dataset = tf.contrib.data.TFRecordDataset(tfrecord_filenames)
-            feature = {'train/data': tf.FixedLenFeature([], tf.string),
-                       'train/label': tf.FixedLenFeature([], tf.int64)}
-            my_example_parser = partial(
-                example_parser, feature=feature, data_shape=input_shape)
-            dataset = dataset.map(
-                my_example_parser, num_threads=1, output_buffer_size=batch_size)
-            dataset = dataset.batch(batch_size)
-            images, labels = dataset.make_one_shot_iterator().get_next()
+            data, labels = config.get_data_and_labels()
 
         # Build a Graph that computes the logits predictions from the
         # inference model.
-        logits = architecture(images, mode=tf.estimator.ModeKeys.EVAL)
+        logits = config.architecture(data)
+        tf.add_to_collection('logits', logits)
 
         # Calculate predictions.
         top_k_op = tf.nn.in_top_k(logits, labels, 1)
@@ -68,22 +93,23 @@ def main(argv=None):
         saver = tf.train.Saver()
         # Build the summary operation based on the TF collection of Summaries.
         summary_op = tf.summary.merge_all()
-        summary_writer = tf.summary.FileWriter(eval_dir, g)
-        evaluated_file = os.path.join(eval_dir, 'evaluated')
+        summary_writer = tf.summary.FileWriter(config.eval_dir, graph)
+        evaluated_file = os.path.join(config.eval_dir, 'evaluated')
 
         while True:
             evaluated_steps = []
             if os.path.exists(evaluated_file):
                 with open(evaluated_file) as f:
                     evaluated_steps = f.read().split()
 
-            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
+            ckpt = tf.train.get_checkpoint_state(config.checkpoint_dir)
             if ckpt and ckpt.model_checkpoint_path:
                 for path in ckpt.all_model_checkpoint_paths:
                     global_step = get_global_step(path)
                     if global_step not in evaluated_steps:
-                        ret_val = eval_once(saver, summary_writer, top_k_op, summary_op,
-                                            path, global_step, num_examples,
-                                            batch_size)
+                        ret_val = eval_once(saver, summary_writer, top_k_op,
+                                            summary_op, path, global_step,
+                                            num_examples,
+                                            config.batch_size)
                         if ret_val == 0:
                             with open(evaluated_file, 'a') as f:
                                 f.write(global_step + '\n')
diff --git a/bob/learn/tensorflow/script/train_generic.py b/bob/learn/tensorflow/script/train_generic.py
index c757f45c0d12e8b5411866e08677876b166804e8..188910e64f9535c451c4358621c993e29ae18d40 100644
--- a/bob/learn/tensorflow/script/train_generic.py
+++ b/bob/learn/tensorflow/script/train_generic.py
@@ -1,91 +1,121 @@
 #!/usr/bin/env python
-"""Trains the VGG-audio network on the AVspoof database.
+"""Trains networks using tf.train.MonitoredTrainingSession
 
 Usage:
-  %(prog)s [options] <train_tfrecords>...
+  %(prog)s [options] <config_files>...
+  %(prog)s --help
+  %(prog)s --version
+
+Arguments:
+  <config_files>  The configuration files to load. They are loaded in order
+                  and, together, they need to define several objects.
+                  See below for an explanation.
 
 Options:
-  -h --help              Show this screen.
-  --save-dir PATH        [default: /idiap/user/amohammadi/avspoof/specgram/avspoof-simple-cnn-train]
-  --input-shape N        [default: (50, 1024, 1)]
-  --epochs N             [default: None]
-  --batch-size N         [default: 32]
-  --capacity-samples N   The capacity of the queue [default: 10**4/2].
-  --learning-rate N      The learning rate [default: 0.00001].
-  --log-frequency N      How often to log results to the console.
-                         [default: 100]
+  -h --help  show this help message and exit
+  --version  show version and exit
+
+Together, the configuration files should define the following objects:
+
+  ## Required objects:
+
+  # checkpoint_dir
+  checkpoint_dir = 'train'
+  batch_size
+  data, labels = get_data_and_labels()
+  logits = architecture(data)
+  loss = loss(logits, labels)
+  train_op = optimizer.minimize(loss, global_step=global_step)
+
+  ## Optional objects:
+
+  log_frequency
+  max_to_keep
+
+Example configuration::
+
+    import tensorflow as tf
+
+    checkpoint_dir = 'avspoof-simple-cnn-train'
+    tfrecord_filenames = ['/path/to/group.tfrecords']
+    data_shape = (50, 1024, 1)
+    data_type = tf.float32
+    batch_size = 32
+    epochs = None
+    learning_rate = 0.00001
+
+    from bob.learn.tensorflow.utils.tfrecords import shuffle_data_and_labels
+    def get_data_and_labels():
+        return shuffle_data_and_labels(tfrecord_filenames, data_shape,
+                                       data_type, batch_size, epochs=epochs)
+
+    from bob.pad.voice.architectures.simple_cnn import architecture
+
+    def loss(logits, labels):
+        predictor = tf.nn.sparse_softmax_cross_entropy_with_logits(
+            logits=logits, labels=labels)
+        return tf.reduce_mean(predictor)
+
+    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
 """
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-# for bob imports to work properly:
+# import pkg_resources so that bob imports work properly:
 import pkg_resources
 # for creating reproducible nets
-from bob.dap.base.sandbox.reproducible import session_conf
+from ..utils.reproducible import session_conf
 
 import tensorflow as tf
-from docopt import docopt
-from bob.io.base import create_directories_safe
-from bob.dap.voice.architectures.simple_cnn import architecture
-from bob.dap.base.database.tfrecords import read_and_decode
-from bob.dap.base.util.hooks import LoggerHook
+from bob.bio.base.utils import read_config_file
+from ..utils.hooks import LoggerHook
 
 
 def main(argv=None):
-    arguments = docopt(__doc__, argv=argv)
-    print(arguments)
-    input_shape = eval(arguments['--input-shape'])
-    tfrecord_filenames = arguments['<train_tfrecords>']
-    save_dir = arguments['--save-dir']
-    epochs = eval(arguments['--epochs'])
-    batch_size = eval(arguments['--batch-size'])
-    capacity_samples = eval(arguments['--capacity-samples'])
-    learning_rate = eval(arguments['--learning-rate'])
-    log_frequency = eval(arguments['--log-frequency'])
-
-    create_directories_safe(save_dir)
+    from docopt import docopt
+    import os
+    import sys
+    docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
+    version = pkg_resources.require('bob.learn.tensorflow')[0].version
+    args = docopt(docs, argv=argv, version=version)
+    config_files = args['<config_files>']
+    config = read_config_file(config_files)
+
+    max_to_keep = getattr(config, 'max_to_keep', 10**5)
+    log_frequency = getattr(config, 'log_frequency', 100)
+
     with tf.Graph().as_default():
         global_step = tf.contrib.framework.get_or_create_global_step()
 
-        # Get images and labels
+        # Get data and labels
         with tf.name_scope('input'):
-            filename_queue = tf.train.string_input_producer(
-                tfrecord_filenames, num_epochs=epochs, name="tfrecord_filenames")
-
-            image, label = read_and_decode(filename_queue, input_shape)
-            images, labels = tf.train.shuffle_batch(
-                [image, label], batch_size=batch_size,
-                capacity=capacity_samples // batch_size,
-                min_after_dequeue=int(capacity_samples // batch_size // 2),
-                num_threads=1, name="shuffle_batch")
+            data, labels = config.get_data_and_labels()
 
         # Build a Graph that computes the logits predictions from the
         # inference model.
-        logits = architecture(images)
+        logits = config.architecture(data)
         tf.add_to_collection('logits', logits)
 
         # Calculate loss.
-        predictor = tf.nn.sparse_softmax_cross_entropy_with_logits(
-            logits=logits, labels=labels)
-        loss = tf.reduce_mean(predictor)
+        loss = config.loss(logits=logits, labels=labels)
         tf.summary.scalar('loss', loss)
 
-        # Build a Graph that trains the model with one batch of examples and
-        # updates the model parameters.
-        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
-        train_op = optimizer.minimize(loss, global_step=global_step)
+        # get training operation using optimizer:
+        train_op = config.optimizer.minimize(loss, global_step=global_step)
 
-        saver = tf.train.Saver(max_to_keep=10**5)
+        saver = tf.train.Saver(max_to_keep=max_to_keep)
         scaffold = tf.train.Scaffold(saver=saver)
+
         with tf.train.MonitoredTrainingSession(
-                checkpoint_dir=save_dir,
+                checkpoint_dir=config.checkpoint_dir,
                 scaffold=scaffold,
                 hooks=[
-                    tf.train.CheckpointSaverHook(
-                        save_dir, save_secs=60 * 29, scaffold=scaffold),
+                    tf.train.CheckpointSaverHook(config.checkpoint_dir,
+                                                 save_secs=60 * 29,
+                                                 scaffold=scaffold),
                     tf.train.NanTensorHook(loss),
-                    LoggerHook(loss, batch_size, log_frequency)],
+                    LoggerHook(loss, config.batch_size, log_frequency)],
                 config=session_conf,
                 save_checkpoint_secs=None,
                 save_summaries_steps=100,
diff --git a/bob/learn/tensorflow/test/test_db_to_tfrecords.py b/bob/learn/tensorflow/test/test_db_to_tfrecords.py
index e241159638e956bb8e1e8f4cbed32b6e78b2d032..64e9804c1366679a922d821e0a4a993106ab5579 100755
--- a/bob/learn/tensorflow/test/test_db_to_tfrecords.py
+++ b/bob/learn/tensorflow/test/test_db_to_tfrecords.py
@@ -19,7 +19,7 @@ def test_verify_and_tfrecords():
     with open(dummy_config) as f, open(config_path, 'w') as f2:
         f2.write(f.read().replace('TEST_DIR', test_dir))
 
-    parameters = [os.path.join(config_path)]
+    parameters = [config_path]
     try:
         verify(parameters)
         tfrecords(parameters)
diff --git a/bob/learn/tensorflow/test/test_eval.py b/bob/learn/tensorflow/test/test_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..2441ac53402af1fd433d76b13250d6ff4a180b9e
--- /dev/null
+++ b/bob/learn/tensorflow/test/test_eval.py
@@ -0,0 +1,182 @@
+from __future__ import print_function
+import warnings as _warnings
+import sys as _sys
+import os
+from tempfile import mkdtemp
+import logging
+logging.getLogger("tensorflow").setLevel(logging.WARNING)
+import tensorflow as tf
+from bob.io.base.test_utils import datafile
+from bob.io.base import create_directories_safe
+
+from bob.learn.tensorflow.utils.eval import eval_once
+from bob.learn.tensorflow.utils.tfrecords import batch_data_and_labels
+from bob.learn.tensorflow.script.db_to_tfrecords import main as tfrecords
+from bob.bio.base.script.verify import main as verify
+from bob.learn.tensorflow.script.train_generic import main as train_generic
+
+dummy_config = datafile('dummy_verify_config.py', __name__)
+DATA_SHAPE = (1, 112, 92)
+
+
+# from https://stackoverflow.com/a/19299884
+class TemporaryDirectory(object):
+    """Create and return a temporary directory. This has the same
+    behavior as mkdtemp but can be used as a context manager. For
+    example:
+
+        with TemporaryDirectory() as tmpdir:
+            ...
+
+    Upon exiting the context, the directory and everything contained
+    in it are removed.
+ """ + + def __init__(self, suffix="", prefix="tmp", dir=None): + self._closed = False + self.name = None # Handle mkdtemp raising an exception + self.name = mkdtemp(suffix, prefix, dir) + + def __repr__(self): + return "<{} {!r}>".format(self.__class__.__name__, self.name) + + def __enter__(self): + return self.name + + def cleanup(self, _warn=False): + if self.name and not self._closed: + try: + self._rmtree(self.name) + except (TypeError, AttributeError) as ex: + # Issue #10188: Emit a warning on stderr + # if the directory could not be cleaned + # up due to missing globals + if "None" not in str(ex): + raise + print("ERROR: {!r} while cleaning up {!r}".format(ex, self,), + file=_sys.stderr) + return + self._closed = True + if _warn: + self._warn("Implicitly cleaning up {!r}".format(self), + ResourceWarning) + + def __exit__(self, exc, value, tb): + self.cleanup() + + def __del__(self): + # Issue a ResourceWarning if implicit cleanup needed + self.cleanup(_warn=True) + + # XXX (ncoghlan): The following code attempts to make + # this class tolerant of the module nulling out process + # that happens during CPython interpreter shutdown + # Alas, it doesn't actually manage it. See issue #10188 + _listdir = staticmethod(os.listdir) + _path_join = staticmethod(os.path.join) + _isdir = staticmethod(os.path.isdir) + _islink = staticmethod(os.path.islink) + _remove = staticmethod(os.remove) + _rmdir = staticmethod(os.rmdir) + _warn = _warnings.warn + + def _rmtree(self, path): + # Essentially a stripped down version of shutil.rmtree. We can't + # use globals because they may be None'ed out at shutdown. + for name in self._listdir(path): + fullname = self._path_join(path, name) + try: + isdir = self._isdir(fullname) and not self._islink(fullname) + except OSError: + isdir = False + if isdir: + self._rmtree(fullname) + else: + try: + self._remove(fullname) + except OSError: + pass + try: + self._rmdir(path) + except OSError: + pass + + +def _create_tfrecord(test_dir): + config_path = os.path.join(test_dir, 'tfrecordconfig.py') + with open(dummy_config) as f, open(config_path, 'w') as f2: + f2.write(f.read().replace('TEST_DIR', test_dir)) + verify([config_path]) + tfrecords([config_path]) + return os.path.join(test_dir, 'sub_directory', 'dev.tfrecords') + + +def _create_checkpoint(checkpoint_dir, dummy_tfrecord): + config = ''' +import tensorflow as tf + +checkpoint_dir = "{}" +tfrecord_filenames = ['{}'] +data_shape = {} +data_type = tf.uint8 +batch_size = 32 +epochs = 1 +learning_rate = 0.00001 + +from bob.learn.tensorflow.utils.tfrecords import shuffle_data_and_labels +def get_data_and_labels(): + return shuffle_data_and_labels(tfrecord_filenames, data_shape, data_type, + batch_size, epochs=epochs) + +def architecture(images): + images = tf.cast(images, tf.float32) + logits = tf.reshape(images, [-1, 92 * 112]) + logits = tf.layers.dense(inputs=logits, units=20, + activation=tf.nn.relu) + return logits + +def loss(logits, labels): + predictor = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=labels) + return tf.reduce_mean(predictor) + +optimizer = tf.train.GradientDescentOptimizer(learning_rate) +'''.format(checkpoint_dir, dummy_tfrecord, DATA_SAHPE) + create_directories_safe(checkpoint_dir) + config_path = os.path.join(checkpoint_dir, 'config.py') + with open(config_path, 'w') as f: + f.write(config) + train_generic([config_path]) + + +def test_eval_once(): + batch_size = 2 + with tf.Graph().as_default() as graph, TemporaryDirectory() as tmpdir: + + checkpoint_dir = 
+        dummy_tfrecord = _create_tfrecord(tmpdir)
+        _create_checkpoint(checkpoint_dir, dummy_tfrecord)
+
+        # Get images and labels
+        with tf.name_scope('input'):
+            images, labels = batch_data_and_labels(
+                [dummy_tfrecord], DATA_SHAPE, tf.uint8, batch_size, epochs=1)
+
+        images = tf.cast(images, tf.float32)
+        logits = tf.reshape(images, [-1, 92 * 112])
+        logits = tf.layers.dense(inputs=logits, units=20,
+                                 activation=tf.nn.relu)
+
+        # Calculate predictions.
+        prediction_op = tf.nn.in_top_k(logits, labels, 1)
+
+        saver = tf.train.Saver()
+        # Build the summary operation based on the TF collection of Summaries.
+        summary_op = tf.summary.merge_all()
+        summary_writer = tf.summary.FileWriter(tmpdir, graph)
+
+        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
+        path = ckpt.model_checkpoint_path
+        ret_val = eval_once(saver, summary_writer, prediction_op, summary_op,
+                            path, '1', None, batch_size)
+        assert ret_val == 0, str(ret_val)
diff --git a/bob/learn/tensorflow/utils/__init__.py b/bob/learn/tensorflow/utils/__init__.py
index e73a1da73733affd340604379e89fbecf765fac0..3fe013e8ccb40c8512359a1774c63f0513e18075 100755
--- a/bob/learn/tensorflow/utils/__init__.py
+++ b/bob/learn/tensorflow/utils/__init__.py
@@ -1,3 +1,6 @@
 from .util import *
 from .singleton import Singleton
-from .session import Session
\ No newline at end of file
+from .session import Session
+from . import hooks
+from . import eval
+from . import tfrecords
diff --git a/bob/learn/tensorflow/utils/eval.py b/bob/learn/tensorflow/utils/eval.py
index ce20ed13c374fbe3fc03aa6dc2ea26115e3945a7..94f9ca43dfd4da1d6dc97550ff5972a2b7a36e42 100644
--- a/bob/learn/tensorflow/utils/eval.py
+++ b/bob/learn/tensorflow/utils/eval.py
@@ -30,6 +30,19 @@ def get_global_step(path):
     return global_step
 
 
+def _log_precision(true_count, total_sample_count, global_step, sess,
+                   summary_writer, summary_op):
+    # Compute precision @ 1.
+    precision = true_count / total_sample_count
+    print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
+
+    summary = tf.Summary()
+    summary.ParseFromString(sess.run(summary_op))
+    summary.value.add(tag='Precision @ 1', simple_value=precision)
+    summary_writer.add_summary(summary, global_step)
+    return 0
+
+
 def eval_once(saver, summary_writer, prediction_op, summary_op,
               model_checkpoint_path, global_step, num_examples, batch_size):
     """Run Eval once.
@@ -59,7 +72,6 @@ def eval_once(saver, summary_writer, prediction_op, summary_op,
     -------
     int
         0 for success, anything else for fail.
-
     """
     with tf.Session() as sess:
         sess.run(tf.local_variables_initializer())
@@ -77,8 +89,8 @@ def eval_once(saver, summary_writer, prediction_op, summary_op,
         try:
             threads = []
             for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
-                threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
-                                                 start=True))
+                threads.extend(qr.create_threads(sess, coord=coord,
+                                                 daemon=True, start=True))
 
             if num_examples is None:
                 num_iter = float("inf")
@@ -93,15 +105,14 @@ def eval_once(saver, summary_writer, prediction_op, summary_op,
                 total_sample_count += np.asarray(predictions).size
                 step += 1
 
-            # Compute precision @ 1.
-            precision = true_count / total_sample_count
-            print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
-
-            summary = tf.Summary()
-            summary.ParseFromString(sess.run(summary_op))
-            summary.value.add(tag='Precision @ 1', simple_value=precision)
-            summary_writer.add_summary(summary, global_step)
-            return 0
+            return _log_precision(true_count, total_sample_count,
+                                  global_step, sess, summary_writer,
+                                  summary_op)
+        except tf.errors.OutOfRangeError as e:
+            coord.request_stop(e)
+            return _log_precision(true_count, total_sample_count,
+                                  global_step, sess, summary_writer,
+                                  summary_op)
         except Exception as e:  # pylint: disable=broad-except
             coord.request_stop(e)
             return -1
diff --git a/bob/learn/tensorflow/utils/tfrecords.py b/bob/learn/tensorflow/utils/tfrecords.py
index dcab80c86d940a89534dee4a02aeefe9b5493de6..b8c73553806fbda0e4dfaf53536399e1544a7909 100644
--- a/bob/learn/tensorflow/utils/tfrecords.py
+++ b/bob/learn/tensorflow/utils/tfrecords.py
@@ -1,12 +1,12 @@
 import tensorflow as tf
 
 
-def example_parser(serialized_example, feature, data_shape):
+def example_parser(serialized_example, feature, data_shape, data_type):
     """Parses a single tf.Example into image and label tensors."""
     # Decode the record read by the reader
     features = tf.parse_single_example(serialized_example, features=feature)
     # Convert the image data from string back to the numbers
-    image = tf.decode_raw(features['train/data'], tf.float32)
+    image = tf.decode_raw(features['train/data'], data_type)
     # Cast label data into int64
     label = tf.cast(features['train/label'], tf.int64)
     # Reshape image data into the original shape
@@ -14,12 +14,50 @@ def example_parser(serialized_example, feature, data_shape):
     return image, label
 
 
-def read_and_decode(filename_queue, data_shape, feature=None):
-
+def read_and_decode(filename_queue, data_shape, data_type=tf.float32,
+                    feature=None):
     if feature is None:
         feature = {'train/data': tf.FixedLenFeature([], tf.string),
                    'train/label': tf.FixedLenFeature([], tf.int64)}
     # Define a reader and read the next record
     reader = tf.TFRecordReader()
     _, serialized_example = reader.read(filename_queue)
-    return example_parser(serialized_example, feature, data_shape)
+    return example_parser(serialized_example, feature, data_shape, data_type)
+
+
+def _read_data_and_labels(tfrecord_filenames, data_shape, data_type,
+                          epochs=None):
+
+    filename_queue = tf.train.string_input_producer(
+        tfrecord_filenames, num_epochs=epochs, name="tfrecord_filenames")
+
+    data, label = read_and_decode(filename_queue, data_shape, data_type)
+    return data, label
+
+
+def shuffle_data_and_labels(tfrecord_filenames, data_shape, data_type,
+                            batch_size, epochs=None, capacity=10**3,
+                            min_after_dequeue=None, num_threads=1):
+    if min_after_dequeue is None:
+        min_after_dequeue = capacity // 2
+    data, label = _read_data_and_labels(
+        tfrecord_filenames, data_shape, data_type, epochs)
+
+    datas, labels = tf.train.shuffle_batch(
+        [data, label], batch_size=batch_size,
+        capacity=capacity,
+        min_after_dequeue=min_after_dequeue,
+        num_threads=num_threads, name="shuffle_batch")
+    return datas, labels
+
+
+def batch_data_and_labels(tfrecord_filenames, data_shape, data_type,
+                          batch_size, epochs=1, capacity=10**3, num_threads=1):
+    data, label = _read_data_and_labels(
+        tfrecord_filenames, data_shape, data_type, epochs)
+
+    datas, labels = tf.train.batch(
+        [data, label], batch_size=batch_size,
+        capacity=capacity,
+        num_threads=num_threads, name="batch")
+    return datas, labels
diff --git a/setup.py b/setup.py
index 493222372fc3ca1dcd2f34db879ebd3277e7bd4a..6a0b4dbdcf9cca07279ef5a0157348188e85d9cf 100755
--- a/setup.py
+++ b/setup.py
@@ -51,7 +51,9 @@ setup(
       'train.py = bob.learn.tensorflow.script.train:main',
       'bob_db_to_tfrecords = bob.learn.tensorflow.script.db_to_tfrecords:main',
       'load_and_debug.py = bob.learn.tensorflow.script.load_and_debug:main',
-      'lfw_db_to_tfrecords.py = bob.learn.tensorflow.script.lfw_db_to_tfrecords:main'
+      'lfw_db_to_tfrecords.py = bob.learn.tensorflow.script.lfw_db_to_tfrecords:main',
+      'bob_tf_train_generic = bob.learn.tensorflow.script.train_generic:main',
+      'bob_tf_eval_generic = bob.learn.tensorflow.script.eval_generic:main',
     ],
   },
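
Usage note (a minimal sketch, not part of the patch itself): with the two new
entry points installed, an end-to-end run could look like the following. The
config below mirrors the fixture in test_eval.py; the file names, paths, and
the toy one-layer architecture are illustrative assumptions only::

    # train_config.py -- hypothetical config consumed by bob_tf_train_generic
    import tensorflow as tf

    checkpoint_dir = 'my-train'
    tfrecord_filenames = ['/path/to/train.tfrecords']
    data_shape = (1, 112, 92)
    data_type = tf.uint8
    batch_size = 32
    epochs = 1
    learning_rate = 0.00001

    from bob.learn.tensorflow.utils.tfrecords import shuffle_data_and_labels

    def get_data_and_labels():
        # queue-based shuffled batches, as in the docstring examples above
        return shuffle_data_and_labels(tfrecord_filenames, data_shape,
                                       data_type, batch_size, epochs=epochs)

    def architecture(data):
        # toy architecture: flatten the image and apply one dense layer
        data = tf.cast(data, tf.float32)
        logits = tf.reshape(data, [-1, 92 * 112])
        return tf.layers.dense(inputs=logits, units=20,
                               activation=tf.nn.relu)

    def loss(logits, labels):
        predictor = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels)
        return tf.reduce_mean(predictor)

    optimizer = tf.train.GradientDescentOptimizer(learning_rate)

and then::

    $ bob_tf_train_generic train_config.py
    $ bob_tf_eval_generic eval_config.py  # analogous config adding eval_dir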