diff --git a/bob/learn/tensorflow/script/eval_generic.py b/bob/learn/tensorflow/script/eval_generic.py index 42e1fbdc415fb0efa6c6ad38bd758fc79add2101..1a78e585b347276b852d16a941711f73f0b4ef53 100644 --- a/bob/learn/tensorflow/script/eval_generic.py +++ b/bob/learn/tensorflow/script/eval_generic.py @@ -91,8 +91,6 @@ def main(argv=None): top_k_op = tf.nn.in_top_k(logits, labels, 1) saver = tf.train.Saver() - # Build the summary operation based on the TF collection of Summaries. - summary_op = tf.summary.merge_all() summary_writer = tf.summary.FileWriter(config.eval_dir, graph) evaluated_file = os.path.join(config.eval_dir, 'evaluated') @@ -107,7 +105,7 @@ def main(argv=None): global_step = get_global_step(path) if global_step not in evaluated_steps: ret_val = eval_once(saver, summary_writer, top_k_op, - summary_op, path, global_step, + path, global_step, num_examples, config.batch_size) if ret_val == 0: diff --git a/bob/learn/tensorflow/test/test_eval.py b/bob/learn/tensorflow/test/test_eval.py index 2441ac53402af1fd433d76b13250d6ff4a180b9e..8c7b2226e91514b98d68b01cf50acd34592d9469 100644 --- a/bob/learn/tensorflow/test/test_eval.py +++ b/bob/learn/tensorflow/test/test_eval.py @@ -1,107 +1,21 @@ from __future__ import print_function -import warnings as _warnings -import sys as _sys import os from tempfile import mkdtemp +import shutil import logging logging.getLogger("tensorflow").setLevel(logging.WARNING) -import tensorflow as tf from bob.io.base.test_utils import datafile from bob.io.base import create_directories_safe -from bob.learn.tensorflow.utils.eval import eval_once -from bob.learn.tensorflow.utils.tfrecords import batch_data_and_labels from bob.learn.tensorflow.script.db_to_tfrecords import main as tfrecords from bob.bio.base.script.verify import main as verify from bob.learn.tensorflow.script.train_generic import main as train_generic +from bob.learn.tensorflow.script.eval_generic import main as eval_generic dummy_config = datafile('dummy_verify_config.py', __name__) DATA_SAHPE = (1, 112, 92) -# from https://stackoverflow.com/a/19299884 -class TemporaryDirectory(object): - """Create and return a temporary directory. This has the same - behavior as mkdtemp but can be used as a context manager. For - example: - - with TemporaryDirectory() as tmpdir: - ... - - Upon exiting the context, the directory and everything contained - in it are removed. 
- """ - - def __init__(self, suffix="", prefix="tmp", dir=None): - self._closed = False - self.name = None # Handle mkdtemp raising an exception - self.name = mkdtemp(suffix, prefix, dir) - - def __repr__(self): - return "<{} {!r}>".format(self.__class__.__name__, self.name) - - def __enter__(self): - return self.name - - def cleanup(self, _warn=False): - if self.name and not self._closed: - try: - self._rmtree(self.name) - except (TypeError, AttributeError) as ex: - # Issue #10188: Emit a warning on stderr - # if the directory could not be cleaned - # up due to missing globals - if "None" not in str(ex): - raise - print("ERROR: {!r} while cleaning up {!r}".format(ex, self,), - file=_sys.stderr) - return - self._closed = True - if _warn: - self._warn("Implicitly cleaning up {!r}".format(self), - ResourceWarning) - - def __exit__(self, exc, value, tb): - self.cleanup() - - def __del__(self): - # Issue a ResourceWarning if implicit cleanup needed - self.cleanup(_warn=True) - - # XXX (ncoghlan): The following code attempts to make - # this class tolerant of the module nulling out process - # that happens during CPython interpreter shutdown - # Alas, it doesn't actually manage it. See issue #10188 - _listdir = staticmethod(os.listdir) - _path_join = staticmethod(os.path.join) - _isdir = staticmethod(os.path.isdir) - _islink = staticmethod(os.path.islink) - _remove = staticmethod(os.remove) - _rmdir = staticmethod(os.rmdir) - _warn = _warnings.warn - - def _rmtree(self, path): - # Essentially a stripped down version of shutil.rmtree. We can't - # use globals because they may be None'ed out at shutdown. - for name in self._listdir(path): - fullname = self._path_join(path, name) - try: - isdir = self._isdir(fullname) and not self._islink(fullname) - except OSError: - isdir = False - if isdir: - self._rmtree(fullname) - else: - try: - self._remove(fullname) - except OSError: - pass - try: - self._rmdir(path) - except OSError: - pass - - def _create_tfrecord(test_dir): config_path = os.path.join(test_dir, 'tfrecordconfig.py') with open(dummy_config) as f, open(config_path, 'w') as f2: @@ -149,34 +63,61 @@ optimizer = tf.train.GradientDescentOptimizer(learning_rate) train_generic([config_path]) -def test_eval_once(): - batch_size = 2 - with tf.Graph().as_default() as graph, TemporaryDirectory() as tmpdir: +def _eval(checkpoint_dir, eval_dir, dummy_tfrecord): + config = ''' +import tensorflow as tf +checkpoint_dir = '{}' +eval_dir = '{}' +tfrecord_filenames = ['{}'] +data_shape = {} +data_type = tf.uint8 +batch_size = 2 +run_once = True + +from bob.learn.tensorflow.utils.tfrecords import batch_data_and_labels +def get_data_and_labels(): + return batch_data_and_labels(tfrecord_filenames, data_shape, data_type, + batch_size) + +def architecture(images): + images = tf.cast(images, tf.float32) + logits = tf.reshape(images, [-1, 92 * 112]) + logits = tf.layers.dense(inputs=logits, units=20, + activation=tf.nn.relu) + return logits +'''.format(checkpoint_dir, eval_dir, dummy_tfrecord, DATA_SAHPE) + create_directories_safe(eval_dir) + config_path = os.path.join(eval_dir, 'config.py') + with open(config_path, 'w') as f: + f.write(config) + eval_generic([config_path]) + + +def test_eval_once(): + tmpdir = mkdtemp(prefix='bob_') + try: checkpoint_dir = os.path.join(tmpdir, 'checkpoint_dir') + eval_dir = os.path.join(tmpdir, 'eval_dir') + + print('\nCreating a dummy tfrecord') dummy_tfrecord = _create_tfrecord(tmpdir) + + print('Training a dummy network') _create_checkpoint(checkpoint_dir, dummy_tfrecord) - # 
Get images and labels - with tf.name_scope('input'): - images, labels = batch_data_and_labels( - [dummy_tfrecord], DATA_SAHPE, tf.uint8, batch_size, epochs=1) - - images = tf.cast(images, tf.float32) - logits = tf.reshape(images, [-1, 92 * 112]) - logits = tf.layers.dense(inputs=logits, units=20, - activation=tf.nn.relu) - - # Calculate predictions. - prediction_op = tf.nn.in_top_k(logits, labels, 1) - - saver = tf.train.Saver() - # Build the summary operation based on the TF collection of Summaries. - summary_op = tf.summary.merge_all() - summary_writer = tf.summary.FileWriter(tmpdir, graph) - - ckpt = tf.train.get_checkpoint_state(checkpoint_dir) - path = ckpt.model_checkpoint_path - ret_val = eval_once(saver, summary_writer, prediction_op, summary_op, - path, '1', None, batch_size) - assert ret_val == 0, str(ret_val) + print('Evaluating a dummy network') + _eval(checkpoint_dir, eval_dir, dummy_tfrecord) + + evaluated_path = os.path.join(eval_dir, 'evaluated') + assert os.path.exists(evaluated_path) + with open(evaluated_path) as f: + doc = f.read() + + assert '1' in doc, doc + assert '7' in doc, doc + finally: + try: + shutil.rmtree(tmpdir) + except Exception: + pass diff --git a/bob/learn/tensorflow/utils/eval.py b/bob/learn/tensorflow/utils/eval.py index 94f9ca43dfd4da1d6dc97550ff5972a2b7a36e42..edd1cc248bfb9892b72db09dbc05ae242651c388 100644 --- a/bob/learn/tensorflow/utils/eval.py +++ b/bob/learn/tensorflow/utils/eval.py @@ -31,19 +31,19 @@ def get_global_step(path): def _log_precision(true_count, total_sample_count, global_step, sess, - summary_writer, summary_op): + summary_writer): # Compute precision @ 1. precision = true_count / total_sample_count - print('%s: precision @ 1 = %.3f' % (datetime.now(), precision)) + print('%s: precision @ 1 = %.3f (global_step %s)' % + (datetime.now(), precision, global_step)) summary = tf.Summary() - summary.ParseFromString(sess.run(summary_op)) summary.value.add(tag='Precision @ 1', simple_value=precision) summary_writer.add_summary(summary, global_step) return 0 -def eval_once(saver, summary_writer, prediction_op, summary_op, +def eval_once(saver, summary_writer, prediction_op, model_checkpoint_path, global_step, num_examples, batch_size): """Run Eval once. @@ -55,8 +55,6 @@ def eval_once(saver, summary_writer, prediction_op, summary_op, Summary writer. prediction_op Prediction operator. - summary_op - Summary operator. model_checkpoint_path : str Path to the model checkpoint. global_step : str @@ -84,38 +82,25 @@ def eval_once(saver, summary_writer, prediction_op, summary_op, print('No checkpoint file found') return -1 - # Start the queue runners. - coord = tf.train.Coordinator() + if num_examples is None: + num_iter = float("inf") + else: + num_iter = int(math.ceil(num_examples / batch_size)) + true_count = 0 # Counts the number of correct predictions. + total_sample_count = 0 + step = 0 + try: - threads = [] - for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS): - threads.extend(qr.create_threads(sess, coord=coord, - daemon=True, start=True)) - - if num_examples is None: - num_iter = float("inf") - else: - num_iter = int(math.ceil(num_examples / batch_size)) - true_count = 0 # Counts the number of correct predictions. 
- total_sample_count = 0 - step = 0 - while step < num_iter and not coord.should_stop(): + while step < num_iter: predictions = sess.run([prediction_op]) true_count += np.sum(predictions) total_sample_count += np.asarray(predictions).size step += 1 return _log_precision(true_count, total_sample_count, - global_step, sess, summary_writer, - summary_op) - except tf.errors.OutOfRangeError as e: - coord.request_stop(e) + global_step, sess, summary_writer) + except tf.errors.OutOfRangeError: return _log_precision(true_count, total_sample_count, - global_step, sess, summary_writer, - summary_op) - except Exception as e: # pylint: disable=broad-except - coord.request_stop(e) + global_step, sess, summary_writer) + except Exception: return -1 - finally: - coord.request_stop() - coord.join(threads, stop_grace_period_secs=10) diff --git a/bob/learn/tensorflow/utils/tfrecords.py b/bob/learn/tensorflow/utils/tfrecords.py index b8c73553806fbda0e4dfaf53536399e1544a7909..48da0740577c2a64e4e6f59b26dac959e0a0678f 100644 --- a/bob/learn/tensorflow/utils/tfrecords.py +++ b/bob/learn/tensorflow/utils/tfrecords.py @@ -1,6 +1,11 @@ +from functools import partial import tensorflow as tf +DEFAULT_FEATURE = {'train/data': tf.FixedLenFeature([], tf.string), + 'train/label': tf.FixedLenFeature([], tf.int64)} + + def example_parser(serialized_example, feature, data_shape, data_type): """Parses a single tf.Example into image and label tensors.""" # Decode the record read by the reader @@ -17,47 +22,39 @@ def example_parser(serialized_example, feature, data_shape, data_type): def read_and_decode(filename_queue, data_shape, data_type=tf.float32, feature=None): if feature is None: - feature = {'train/data': tf.FixedLenFeature([], tf.string), - 'train/label': tf.FixedLenFeature([], tf.int64)} + feature = DEFAULT_FEATURE # Define a reader and read the next record reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) return example_parser(serialized_example, feature, data_shape, data_type) -def _read_data_and_labesl(tfrecord_filenames, data_shape, data_type, - epochs=None): - - filename_queue = tf.train.string_input_producer( - tfrecord_filenames, num_epochs=epochs, name="tfrecord_filenames") - - data, label = read_and_decode(filename_queue, data_shape, data_type) - return data, label +def create_dataset_from_records(tfrecord_filenames, data_shape, data_type, + feature=None): + if feature is None: + feature = DEFAULT_FEATURE + dataset = tf.contrib.data.TFRecordDataset(tfrecord_filenames) + parser = partial(example_parser, feature=feature, data_shape=data_shape, + data_type=data_type) + dataset = dataset.map(parser) + return dataset def shuffle_data_and_labels(tfrecord_filenames, data_shape, data_type, - batch_size, epochs=None, capacity=10**3, - min_after_dequeue=None, num_threads=1): - if min_after_dequeue is None: - min_after_dequeue = capacity // 2 - data, label = _read_data_and_labesl( - tfrecord_filenames, data_shape, data_type, epochs) - - datas, labels = tf.train.shuffle_batch( - [data, label], batch_size=batch_size, - capacity=capacity, - min_after_dequeue=min_after_dequeue, - num_threads=num_threads, name="shuffle_batch") + batch_size, epochs=None, buffer_size=10**3): + dataset = create_dataset_from_records(tfrecord_filenames, data_shape, + data_type) + dataset = dataset.shuffle(buffer_size).batch(batch_size).repeat(epochs) + + datas, labels = dataset.make_one_shot_iterator().get_next() return datas, labels def batch_data_and_labels(tfrecord_filenames, data_shape, data_type, - batch_size, 
epochs=1, capacity=10**3, num_threads=1): - data, label = _read_data_and_labesl( - tfrecord_filenames, data_shape, data_type, epochs) - - datas, labels = tf.train.batch( - [data, label], batch_size=batch_size, - capacity=capacity, - num_threads=num_threads, name="batch") + batch_size, epochs=1): + dataset = create_dataset_from_records(tfrecord_filenames, data_shape, + data_type) + dataset = dataset.batch(batch_size).repeat(epochs) + + datas, labels = dataset.make_one_shot_iterator().get_next() return datas, labels
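
Note: this patch replaces the queue-based input pipeline (tf.train.string_input_producer, tf.train.shuffle_batch / tf.train.batch, and the tf.train.Coordinator / queue-runner bookkeeping in eval_once) with the tf.contrib.data Dataset API, and drops the merged summary op from the evaluation loop. Below is a minimal usage sketch of the refactored batch_data_and_labels helper, not part of the patch; it assumes a TensorFlow 1.x release that ships tf.contrib.data and a hypothetical 'train.tfrecords' file whose records carry the default 'train/data' / 'train/label' features.

    import tensorflow as tf
    from bob.learn.tensorflow.utils.tfrecords import batch_data_and_labels

    # Shape of the dummy test images: (channels, height, width), as in DATA_SAHPE.
    data_shape = (1, 112, 92)

    # The helper now builds a TFRecordDataset, batches it, and returns tensors
    # from a one-shot iterator; no queue runners or Coordinator are required.
    images, labels = batch_data_and_labels(['train.tfrecords'], data_shape,
                                           tf.uint8, batch_size=2, epochs=1)

    with tf.Session() as sess:
        try:
            while True:
                batch_images, batch_labels = sess.run([images, labels])
        except tf.errors.OutOfRangeError:
            # Raised once the single epoch is exhausted; eval_once now relies on
            # this exception instead of coordinator shutdown to stop cleanly.
            pass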