diff --git a/bob/learn/tensorflow/script/eval_generic.py b/bob/learn/tensorflow/script/eval_generic.py index 1a78e585b347276b852d16a941711f73f0b4ef53..4c909635f7aa0e038fb1ec93291e2c62e5545ae2 100644 --- a/bob/learn/tensorflow/script/eval_generic.py +++ b/bob/learn/tensorflow/script/eval_generic.py @@ -20,36 +20,77 @@ The configuration files should have the following objects totally: ## Required objects: - # checkpoint_dir - checkpoint_dir = 'train' - eval_dir = 'eval' - batch_size = 50 - data, labels = get_data_and_labels() - logits = architecture(data) + model_fn + eval_input_fn ## Optional objects: - num_examples - run_once eval_interval_secs + run_once + model_dir + run_config + model_params + steps + hooks + name Example configuration:: import tensorflow as tf - - checkpoint_dir = 'avspoof-simple-cnn-train' - eval_dir = 'avspoof-simple-cnn-eval' - tfrecord_filenames = ['/path/to/dev.tfrecods'] - data_shape = (50, 1024, 1) - data_type = tf.float32 - batch_size = 50 - from bob.learn.tensorflow.utils.tfrecords import batch_data_and_labels - def get_data_and_labels(): - return batch_data_and_labels(tfrecord_filenames, data_shape, data_type, - batch_size) - from bob.pad.voice.architectures.simple_cnn import architecture + model_dir = "%(model_dir)s" + tfrecord_filenames = ['%(tfrecord_filenames)s'] + data_shape = (1, 112, 92) # size of atnt images + data_type = tf.uint8 + batch_size = 2 + epochs = 1 + run_once = True + + def eval_input_fn(): + return batch_data_and_labels(tfrecord_filenames, data_shape, data_type, + batch_size, epochs=epochs) + + def architecture(images): + images = tf.cast(images, tf.float32) + logits = tf.reshape(images, [-1, 92 * 112]) + logits = tf.layers.dense(inputs=logits, units=20, + activation=tf.nn.relu) + return logits + + def model_fn(features, labels, mode, params, config): + logits = architecture(features) + + predictions = { + # Generate predictions (for PREDICT and EVAL mode) + "classes": tf.argmax(input=logits, axis=1), + # Add `softmax_tensor` to the graph. It is used for PREDICT and by + # the `logging_hook`. + "probabilities": tf.nn.softmax(logits, name="softmax_tensor") + } + if mode == tf.estimator.ModeKeys.PREDICT: + return tf.estimator.EstimatorSpec(mode=mode, + predictions=predictions) + + # Calculate Loss (for both TRAIN and EVAL modes) + predictor = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=labels) + loss = tf.reduce_mean(predictor) + + # Configure the Training Op (for TRAIN mode) + if mode == tf.estimator.ModeKeys.TRAIN: + global_step = tf.contrib.framework.get_or_create_global_step() + optimizer = tf.train.GradientDescentOptimizer(learning_rate) + train_op = optimizer.minimize(loss, global_step=global_step) + return tf.estimator.EstimatorSpec(mode=mode, loss=loss, + train_op=train_op) + + # Add evaluation metrics (for EVAL mode) + eval_metric_ops = { + "accuracy": tf.metrics.accuracy( + labels=labels, predictions=predictions["classes"])} + return tf.estimator.EstimatorSpec( + mode=mode, loss=loss, eval_metric_ops=eval_metric_ops) """ from __future__ import absolute_import from __future__ import division @@ -58,9 +99,10 @@ from __future__ import print_function import pkg_resources import os import time +import six import tensorflow as tf from bob.bio.base.utils import read_config_file -from ..utils.eval import get_global_step, eval_once +from ..utils.eval import get_global_step def main(argv=None): @@ -72,48 +114,56 @@ def main(argv=None): config_files = args['<config_files>'] config = read_config_file(config_files) - run_once = getattr(config, 'run_once', False) + model_fn = config.model_fn + eval_input_fn = config.eval_input_fn + eval_interval_secs = getattr(config, 'eval_interval_secs', 300) - num_examples = getattr(config, 'num_examples', None) - - with tf.Graph().as_default() as graph: - - # Get data and labels - with tf.name_scope('input'): - data, labels = config.get_data_and_labels() - - # Build a Graph that computes the logits predictions from the - # inference model. - logits = config.architecture(data) - tf.add_to_collection('logits', logits) - - # Calculate predictions. - top_k_op = tf.nn.in_top_k(logits, labels, 1) - - saver = tf.train.Saver() - summary_writer = tf.summary.FileWriter(config.eval_dir, graph) - evaluated_file = os.path.join(config.eval_dir, 'evaluated') - - while True: - evaluated_steps = [] - if os.path.exists(evaluated_file): - with open(evaluated_file) as f: - evaluated_steps = f.read().split() - ckpt = tf.train.get_checkpoint_state(config.checkpoint_dir) - if ckpt and ckpt.model_checkpoint_path: - for path in ckpt.all_model_checkpoint_paths: - global_step = get_global_step(path) - if global_step not in evaluated_steps: - ret_val = eval_once(saver, summary_writer, top_k_op, - path, global_step, - num_examples, - config.batch_size) - if ret_val == 0: - with open(evaluated_file, 'a') as f: - f.write(global_step + '\n') - if run_once: - break + run_once = getattr(config, 'run_once', False) + model_dir = getattr(config, 'model_dir', None) + run_config = getattr(config, 'run_config', None) + model_params = getattr(config, 'model_params', None) + steps = getattr(config, 'steps', None) + hooks = getattr(config, 'hooks', None) + name = getattr(config, 'eval_name', None) + + # Instantiate Estimator + nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir, + params=model_params, config=run_config) + + evaluated_file = os.path.join(nn.model_dir, name or 'eval', 'evaluated') + while True: + evaluated_steps = [] + if os.path.exists(evaluated_file): + with open(evaluated_file) as f: + evaluated_steps = f.read().split() + + ckpt = tf.train.get_checkpoint_state(nn.model_dir) + if (not ckpt) or (not ckpt.model_checkpoint_path): time.sleep(eval_interval_secs) + continue + + for checkpoint_path in ckpt.all_model_checkpoint_paths: + global_step = str(get_global_step(checkpoint_path)) + if global_step in evaluated_steps: + continue + + # Evaluate + evaluations = nn.evaluate( + input_fn=eval_input_fn, + steps=steps, + hooks=hooks, + checkpoint_path=checkpoint_path, + name=name, + ) + + print(', '.join('%s = %s' % (k, v) + for k, v in sorted(six.iteritems(evaluations)))) + sys.stdout.flush() + with open(evaluated_file, 'a') as f: + f.write('{}\n'.format(evaluations['global_step'])) + if run_once: + break + time.sleep(eval_interval_secs) if __name__ == '__main__': diff --git a/bob/learn/tensorflow/script/train_generic.py b/bob/learn/tensorflow/script/train_generic.py index 188910e64f9535c451c4358621c993e29ae18d40..bd1f46dda5525b9d631ce96c7bb482479b78ad36 100644 --- a/bob/learn/tensorflow/script/train_generic.py +++ b/bob/learn/tensorflow/script/train_generic.py @@ -20,56 +20,83 @@ The configuration files should have the following objects totally: ## Required objects: - # checkpoint_dir - checkpoint_dir = 'train' - batch_size - data, labels = get_data_and_labels() - logits = architecture(data) - loss = loss(logits, labels) - train_op = optimizer.minimize(loss, global_step=global_step) + model_fn + train_input_fn ## Optional objects: - log_frequency - max_to_keep + model_dir + run_config + model_params + hooks + steps + max_steps Example configuration:: import tensorflow as tf - - checkpoint_dir = 'avspoof-simple-cnn-train' - tfrecord_filenames = ['/path/to/group.tfrecod'] - data_shape = (50, 1024, 1) - data_type = tf.float32 - batch_size = 32 - epochs = None + from bob.learn.tensorflow.utils.tfrecords import shuffle_data_and_labels + + model_dir = "%(model_dir)s" + tfrecord_filenames = ['%(tfrecord_filenames)s'] + data_shape = (1, 112, 92) # size of atnt images + data_type = tf.uint8 + batch_size = 2 + epochs = 1 learning_rate = 0.00001 - from bob.learn.tensorflow.utils.tfrecods import shuffle_data_and_labels - def get_data_and_labels(): + def train_input_fn(): return shuffle_data_and_labels(tfrecord_filenames, data_shape, data_type, batch_size, epochs=epochs) - from bob.pad.voice.architectures.simple_cnn import architecture - - def loss(logits, labels): + def architecture(images): + images = tf.cast(images, tf.float32) + logits = tf.reshape(images, [-1, 92 * 112]) + logits = tf.layers.dense(inputs=logits, units=20, + activation=tf.nn.relu) + return logits + + def model_fn(features, labels, mode, params, config): + logits = architecture(features) + + predictions = { + # Generate predictions (for PREDICT and EVAL mode) + "classes": tf.argmax(input=logits, axis=1), + # Add `softmax_tensor` to the graph. It is used for PREDICT and by + # the `logging_hook`. + "probabilities": tf.nn.softmax(logits, name="softmax_tensor") + } + if mode == tf.estimator.ModeKeys.PREDICT: + return tf.estimator.EstimatorSpec(mode=mode, + predictions=predictions) + + # Calculate Loss (for both TRAIN and EVAL modes) predictor = tf.nn.sparse_softmax_cross_entropy_with_logits( logits=logits, labels=labels) - return tf.reduce_mean(predictor) - - optimizer = tf.train.GradientDescentOptimizer(learning_rate) + loss = tf.reduce_mean(predictor) + + # Configure the Training Op (for TRAIN mode) + if mode == tf.estimator.ModeKeys.TRAIN: + global_step = tf.contrib.framework.get_or_create_global_step() + optimizer = tf.train.GradientDescentOptimizer(learning_rate) + train_op = optimizer.minimize(loss, global_step=global_step) + return tf.estimator.EstimatorSpec(mode=mode, loss=loss, + train_op=train_op) + + # Add evaluation metrics (for EVAL mode) + eval_metric_ops = { + "accuracy": tf.metrics.accuracy( + labels=labels, predictions=predictions["classes"])} + return tf.estimator.EstimatorSpec( + mode=mode, loss=loss, eval_metric_ops=eval_metric_ops) """ from __future__ import absolute_import from __future__ import division from __future__ import print_function # import pkg_resources so that bob imports work properly: import pkg_resources -# for creating reproducible nets -from ..utils.reproducible import session_conf - import tensorflow as tf from bob.bio.base.utils import read_config_file -from ..utils.hooks import LoggerHook def main(argv=None): @@ -82,46 +109,29 @@ def main(argv=None): config_files = args['<config_files>'] config = read_config_file(config_files) - max_to_keep = getattr(config, 'max_to_keep', 10**5) - log_frequency = getattr(config, 'log_frequency', 100) - - with tf.Graph().as_default(): - global_step = tf.contrib.framework.get_or_create_global_step() - - # Get data and labels - with tf.name_scope('input'): - data, labels = config.get_data_and_labels() - - # Build a Graph that computes the logits predictions from the - # inference model. - logits = config.architecture(data) - tf.add_to_collection('logits', logits) - - # Calculate loss. - loss = config.loss(logits=logits, labels=labels) - tf.summary.scalar('loss', loss) - - # get training operation using optimizer: - train_op = config.optimizer.minimize(loss, global_step=global_step) - - saver = tf.train.Saver(max_to_keep=max_to_keep) - scaffold = tf.train.Scaffold(saver=saver) - - with tf.train.MonitoredTrainingSession( - checkpoint_dir=config.checkpoint_dir, - scaffold=scaffold, - hooks=[ - tf.train.CheckpointSaverHook(config.checkpoint_dir, - save_secs=60 * 29, - scaffold=scaffold), - tf.train.NanTensorHook(loss), - LoggerHook(loss, config.batch_size, log_frequency)], - config=session_conf, - save_checkpoint_secs=None, - save_summaries_steps=100, - ) as mon_sess: - while not mon_sess.should_stop(): - mon_sess.run(train_op) + model_fn = config.model_fn + train_input_fn = config.train_input_fn + + model_dir = getattr(config, 'model_dir', None) + run_config = getattr(config, 'run_config', None) + model_params = getattr(config, 'model_params', None) + hooks = getattr(config, 'hooks', None) + steps = getattr(config, 'steps', None) + max_steps = getattr(config, 'max_steps', None) + + if run_config is None: + # by default create reproducible nets: + from bob.learn.tensorflow.utils.reproducible import session_conf + run_config = tf.estimator.RunConfig() + run_config.replace(session_config=session_conf) + + # Instantiate Estimator + nn = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir, + params=model_params, config=run_config) + + # Train + nn.train(input_fn=train_input_fn, hooks=hooks, steps=steps, + max_steps=max_steps) if __name__ == '__main__': diff --git a/bob/learn/tensorflow/test/test_estimator_scripts.py b/bob/learn/tensorflow/test/test_estimator_scripts.py new file mode 100644 index 0000000000000000000000000000000000000000..38a8d89031f2da2eef92f3cbc73c835c52f2c083 --- /dev/null +++ b/bob/learn/tensorflow/test/test_estimator_scripts.py @@ -0,0 +1,134 @@ +from __future__ import print_function +import os +from tempfile import mkdtemp +import shutil +import logging +logging.getLogger("tensorflow").setLevel(logging.WARNING) +from bob.io.base.test_utils import datafile + +from bob.learn.tensorflow.script.db_to_tfrecords import main as tfrecords +from bob.bio.base.script.verify import main as verify +from bob.learn.tensorflow.script.train_generic import main as train_generic +from bob.learn.tensorflow.script.eval_generic import main as eval_generic + +dummy_tfrecord_config = datafile('dummy_verify_config.py', __name__) +CONFIG = ''' +import tensorflow as tf +from bob.learn.tensorflow.utils.tfrecords import shuffle_data_and_labels, \ + batch_data_and_labels + +model_dir = "%(model_dir)s" +tfrecord_filenames = ['%(tfrecord_filenames)s'] +data_shape = (1, 112, 92) # size of atnt images +data_type = tf.uint8 +batch_size = 2 +epochs = 1 +learning_rate = 0.00001 +run_once = True + + +def train_input_fn(): + return shuffle_data_and_labels(tfrecord_filenames, data_shape, data_type, + batch_size, epochs=epochs) + +def eval_input_fn(): + return batch_data_and_labels(tfrecord_filenames, data_shape, data_type, + batch_size, epochs=epochs) + +def architecture(images): + images = tf.cast(images, tf.float32) + logits = tf.reshape(images, [-1, 92 * 112]) + logits = tf.layers.dense(inputs=logits, units=20, + activation=tf.nn.relu) + return logits + + +def model_fn(features, labels, mode, params, config): + logits = architecture(features) + + predictions = { + # Generate predictions (for PREDICT and EVAL mode) + "classes": tf.argmax(input=logits, axis=1), + # Add `softmax_tensor` to the graph. It is used for PREDICT and by the + # `logging_hook`. + "probabilities": tf.nn.softmax(logits, name="softmax_tensor") + } + if mode == tf.estimator.ModeKeys.PREDICT: + return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions) + + # Calculate Loss (for both TRAIN and EVAL modes) + predictor = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=labels) + loss = tf.reduce_mean(predictor) + + # Configure the Training Op (for TRAIN mode) + if mode == tf.estimator.ModeKeys.TRAIN: + global_step = tf.contrib.framework.get_or_create_global_step() + optimizer = tf.train.GradientDescentOptimizer(learning_rate) + train_op = optimizer.minimize(loss, global_step=global_step) + return tf.estimator.EstimatorSpec(mode=mode, loss=loss, + train_op=train_op) + + # Add evaluation metrics (for EVAL mode) + eval_metric_ops = { + "accuracy": tf.metrics.accuracy( + labels=labels, predictions=predictions["classes"])} + return tf.estimator.EstimatorSpec( + mode=mode, loss=loss, eval_metric_ops=eval_metric_ops) +''' + + +def _create_tfrecord(test_dir): + config_path = os.path.join(test_dir, 'tfrecordconfig.py') + with open(dummy_tfrecord_config) as f, open(config_path, 'w') as f2: + f2.write(f.read().replace('TEST_DIR', test_dir)) + verify([config_path]) + tfrecords([config_path]) + return os.path.join(test_dir, 'sub_directory', 'dev.tfrecords') + + +def _create_checkpoint(tmpdir, model_dir, dummy_tfrecord): + config = CONFIG % {'model_dir': model_dir, + 'tfrecord_filenames': dummy_tfrecord} + config_path = os.path.join(tmpdir, 'train_config.py') + with open(config_path, 'w') as f: + f.write(config) + train_generic([config_path]) + + +def _eval(tmpdir, model_dir, dummy_tfrecord): + config = CONFIG % {'model_dir': model_dir, + 'tfrecord_filenames': dummy_tfrecord} + config_path = os.path.join(tmpdir, 'eval_config.py') + with open(config_path, 'w') as f: + f.write(config) + eval_generic([config_path]) + + +def test_eval_once(): + tmpdir = mkdtemp(prefix='bob_') + try: + model_dir = os.path.join(tmpdir, 'model_dir') + eval_dir = os.path.join(model_dir, 'eval') + + print('\nCreating a dummy tfrecord') + dummy_tfrecord = _create_tfrecord(tmpdir) + + print('Training a dummy network') + _create_checkpoint(tmpdir, model_dir, dummy_tfrecord) + + print('Evaluating a dummy network') + _eval(tmpdir, model_dir, dummy_tfrecord) + + evaluated_path = os.path.join(eval_dir, 'evaluated') + assert os.path.exists(evaluated_path), evaluated_path + with open(evaluated_path) as f: + doc = f.read() + + assert '1' in doc, doc + assert '100' in doc, doc + finally: + try: + shutil.rmtree(tmpdir) + except Exception: + pass diff --git a/bob/learn/tensorflow/test/test_eval.py b/bob/learn/tensorflow/test/test_eval.py deleted file mode 100644 index 8c7b2226e91514b98d68b01cf50acd34592d9469..0000000000000000000000000000000000000000 --- a/bob/learn/tensorflow/test/test_eval.py +++ /dev/null @@ -1,123 +0,0 @@ -from __future__ import print_function -import os -from tempfile import mkdtemp -import shutil -import logging -logging.getLogger("tensorflow").setLevel(logging.WARNING) -from bob.io.base.test_utils import datafile -from bob.io.base import create_directories_safe - -from bob.learn.tensorflow.script.db_to_tfrecords import main as tfrecords -from bob.bio.base.script.verify import main as verify -from bob.learn.tensorflow.script.train_generic import main as train_generic -from bob.learn.tensorflow.script.eval_generic import main as eval_generic - -dummy_config = datafile('dummy_verify_config.py', __name__) -DATA_SAHPE = (1, 112, 92) - - -def _create_tfrecord(test_dir): - config_path = os.path.join(test_dir, 'tfrecordconfig.py') - with open(dummy_config) as f, open(config_path, 'w') as f2: - f2.write(f.read().replace('TEST_DIR', test_dir)) - verify([config_path]) - tfrecords([config_path]) - return os.path.join(test_dir, 'sub_directory', 'dev.tfrecords') - - -def _create_checkpoint(checkpoint_dir, dummy_tfrecord): - config = ''' -import tensorflow as tf - -checkpoint_dir = "{}" -tfrecord_filenames = ['{}'] -data_shape = {} -data_type = tf.uint8 -batch_size = 32 -epochs = 1 -learning_rate = 0.00001 - -from bob.learn.tensorflow.utils.tfrecords import shuffle_data_and_labels -def get_data_and_labels(): - return shuffle_data_and_labels(tfrecord_filenames, data_shape, data_type, - batch_size, epochs=epochs) - -def architecture(images): - images = tf.cast(images, tf.float32) - logits = tf.reshape(images, [-1, 92 * 112]) - logits = tf.layers.dense(inputs=logits, units=20, - activation=tf.nn.relu) - return logits - -def loss(logits, labels): - predictor = tf.nn.sparse_softmax_cross_entropy_with_logits( - logits=logits, labels=labels) - return tf.reduce_mean(predictor) - -optimizer = tf.train.GradientDescentOptimizer(learning_rate) -'''.format(checkpoint_dir, dummy_tfrecord, DATA_SAHPE) - create_directories_safe(checkpoint_dir) - config_path = os.path.join(checkpoint_dir, 'config.py') - with open(config_path, 'w') as f: - f.write(config) - train_generic([config_path]) - - -def _eval(checkpoint_dir, eval_dir, dummy_tfrecord): - config = ''' -import tensorflow as tf - -checkpoint_dir = '{}' -eval_dir = '{}' -tfrecord_filenames = ['{}'] -data_shape = {} -data_type = tf.uint8 -batch_size = 2 -run_once = True - -from bob.learn.tensorflow.utils.tfrecords import batch_data_and_labels -def get_data_and_labels(): - return batch_data_and_labels(tfrecord_filenames, data_shape, data_type, - batch_size) - -def architecture(images): - images = tf.cast(images, tf.float32) - logits = tf.reshape(images, [-1, 92 * 112]) - logits = tf.layers.dense(inputs=logits, units=20, - activation=tf.nn.relu) - return logits -'''.format(checkpoint_dir, eval_dir, dummy_tfrecord, DATA_SAHPE) - create_directories_safe(eval_dir) - config_path = os.path.join(eval_dir, 'config.py') - with open(config_path, 'w') as f: - f.write(config) - eval_generic([config_path]) - - -def test_eval_once(): - tmpdir = mkdtemp(prefix='bob_') - try: - checkpoint_dir = os.path.join(tmpdir, 'checkpoint_dir') - eval_dir = os.path.join(tmpdir, 'eval_dir') - - print('\nCreating a dummy tfrecord') - dummy_tfrecord = _create_tfrecord(tmpdir) - - print('Training a dummy network') - _create_checkpoint(checkpoint_dir, dummy_tfrecord) - - print('Evaluating a dummy network') - _eval(checkpoint_dir, eval_dir, dummy_tfrecord) - - evaluated_path = os.path.join(eval_dir, 'evaluated') - assert os.path.exists(evaluated_path) - with open(evaluated_path) as f: - doc = f.read() - - assert '1' in doc, doc - assert '7' in doc, doc - finally: - try: - shutil.rmtree(tmpdir) - except Exception: - pass diff --git a/bob/learn/tensorflow/utils/eval.py b/bob/learn/tensorflow/utils/eval.py index edd1cc248bfb9892b72db09dbc05ae242651c388..0fe3c11767526ba70a4f02b1f2e571f0e1a8262f 100644 --- a/bob/learn/tensorflow/utils/eval.py +++ b/bob/learn/tensorflow/utils/eval.py @@ -2,11 +2,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math -import numpy as np -import tensorflow as tf -from datetime import datetime - def get_global_step(path): """Returns the global number associated with the model checkpoint path. The @@ -20,87 +15,10 @@ def get_global_step(path): Returns ------- - global_step : str - The global step number as a string. + global_step : int + The global step number. """ - # Assuming model_checkpoint_path looks something like: - # /my-favorite-path/train/model.ckpt-0, - # extract global_step from it. - global_step = path.split('/')[-1].split('-')[-1] + from tensorflow.python.estimator.estimator import \ + _load_global_step_from_checkpoint_dir + global_step = _load_global_step_from_checkpoint_dir(path) return global_step - - -def _log_precision(true_count, total_sample_count, global_step, sess, - summary_writer): - # Compute precision @ 1. - precision = true_count / total_sample_count - print('%s: precision @ 1 = %.3f (global_step %s)' % - (datetime.now(), precision, global_step)) - - summary = tf.Summary() - summary.value.add(tag='Precision @ 1', simple_value=precision) - summary_writer.add_summary(summary, global_step) - return 0 - - -def eval_once(saver, summary_writer, prediction_op, - model_checkpoint_path, global_step, num_examples, batch_size): - """Run Eval once. - - Parameters - ---------- - saver - Saver. - summary_writer - Summary writer. - prediction_op - Prediction operator. - model_checkpoint_path : str - Path to the model checkpoint. - global_step : str - The global step. - num_examples : int or None - The number of examples to try. - batch_size : int - The size of evaluation batch. - - This function requires the ``from __future__ import division`` import. - - Returns - ------- - int - 0 for success, anything else for fail. - """ - with tf.Session() as sess: - sess.run(tf.local_variables_initializer()) - sess.run(tf.global_variables_initializer()) - - if model_checkpoint_path: - # Restores from checkpoint - saver.restore(sess, model_checkpoint_path) - else: - print('No checkpoint file found') - return -1 - - if num_examples is None: - num_iter = float("inf") - else: - num_iter = int(math.ceil(num_examples / batch_size)) - true_count = 0 # Counts the number of correct predictions. - total_sample_count = 0 - step = 0 - - try: - while step < num_iter: - predictions = sess.run([prediction_op]) - true_count += np.sum(predictions) - total_sample_count += np.asarray(predictions).size - step += 1 - - return _log_precision(true_count, total_sample_count, - global_step, sess, summary_writer) - except tf.errors.OutOfRangeError: - return _log_precision(true_count, total_sample_count, - global_step, sess, summary_writer) - except Exception: - return -1 diff --git a/bob/learn/tensorflow/utils/reproducible.py b/bob/learn/tensorflow/utils/reproducible.py index 87c4ce87dc35488a9db2a3a0233460c70f4916a0..34cb4678258c75d40c889580bb30eff42c8f5242 100644 --- a/bob/learn/tensorflow/utils/reproducible.py +++ b/bob/learn/tensorflow/utils/reproducible.py @@ -2,7 +2,7 @@ import os import numpy as np import tensorflow as tf import random as rn -from tensorflow.contrib import keras +# from tensorflow.contrib import keras # reproducible networks # The below is necessary in Python 3.2.3 onwards to