From 44e40482f03d6e8ee742b45cda7bee67c9392455 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Wed, 23 May 2018 15:18:47 +0200
Subject: [PATCH] Implement model saving in bob tf eval. Fixes #54

---
 bob/learn/tensorflow/script/eval.py           | 105 ++++++++++++++----
 .../tensorflow/test/test_estimator_scripts.py |  37 +++++-
 2 files changed, 121 insertions(+), 21 deletions(-)

diff --git a/bob/learn/tensorflow/script/eval.py b/bob/learn/tensorflow/script/eval.py
index 4a80f24a..1d106fff 100644
--- a/bob/learn/tensorflow/script/eval.py
+++ b/bob/learn/tensorflow/script/eval.py
@@ -4,20 +4,86 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+import click
 import logging
 import os
-import time
 import six
+import shutil
 import sys
 import tensorflow as tf
+import time
+from glob import glob
+from collections import defaultdict
 from ..utils.eval import get_global_step
-import click
 from bob.extension.scripts.click_helper import (
     verbosity_option, ConfigCommand, ResourceOption)
+from bob.io.base import create_directories_safe
 
 logger = logging.getLogger(__name__)
 
 
+def save_best_n_models(train_dir, save_dir, evaluated_file,
+                       keep_n_best_models):
+    create_directories_safe(save_dir)
+    evaluated = read_evaluated_file(evaluated_file)
+
+    def _key(x):
+        x = x[1]
+        ac = x.get('accuracy')
+        lo = x.get('loss') or 0
+        return ac * -1 if ac is not None else lo
+
+    best_models = dict(sorted(
+        evaluated.items(), key=_key)[:keep_n_best_models])
+
+    # delete the old saved models that are not in top N best anymore
+    saved_models = defaultdict(list)
+    for path in glob('{}/model.ckpt-*'.format(save_dir)):
+        global_step = path.split('model.ckpt-')[1].split('.')[0]
+        saved_models[global_step].append(path)
+
+    for global_step, paths in saved_models.items():
+        if global_step not in best_models:
+            for path in paths:
+                logger.info("Deleting `%s'", path)
+                os.remove(path)
+
+    # copy over the best models if not already there
+    for global_step in best_models:
+        for path in glob('{}/model.ckpt-{}.*'.format(train_dir, global_step)):
+            dst = os.path.join(save_dir, os.path.basename(path))
+            if os.path.isfile(dst):
+                continue
+            logger.info("Copying `%s' over to `%s'", path, dst)
+            shutil.copy(path, dst)
+
+
+def read_evaluated_file(path):
+    evaluated = {}
+    with open(path) as f:
+        for line in f:
+            global_step, line = line.split(' ', 1)
+            temp = {}
+            for k_v in line.strip().split(', '):
+                k, v = k_v.split(' = ')
+                v = float(v)
+                if 'global_step' in k:
+                    v = int(v)
+                temp[k] = v
+            evaluated[global_step] = temp
+    return evaluated
+
+
+def append_evaluated_file(path, evaluations):
+    str_evaluations = ', '.join(
+        '%s = %s' % (k, v)
+        for k, v in sorted(six.iteritems(evaluations)))
+    with open(path, 'a') as f:
+        f.write('{} {}\n'.format(evaluations['global_step'],
+                                 str_evaluations))
+    return str_evaluations
+
+
 @click.command(entry_point_group='bob.learn.tensorflow.config',
                cls=ConfigCommand)
 @click.option('--estimator', '-e', required=True, cls=ResourceOption,
@@ -28,12 +94,14 @@ logger = logging.getLogger(__name__)
               entry_point_group='bob.learn.tensorflow.hook')
 @click.option('--run-once', cls=ResourceOption, default=False,
               show_default=True)
-@click.option('--eval-interval-secs', cls=ResourceOption, type=click.types.INT,
+@click.option('--eval-interval-secs', cls=ResourceOption, type=click.INT,
               default=60, show_default=True)
 @click.option('--name', cls=ResourceOption)
+@click.option('--keep-n-best-models', '-K', type=click.INT, cls=ResourceOption,
+              default=0, show_default=True)
 @verbosity_option(cls=ResourceOption)
 def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
-         **kwargs):
+         keep_n_best_models, **kwargs):
     """Evaluates networks using Tensorflow estimators.
 
     \b
@@ -76,18 +144,16 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
     logger.debug('run_once: %s', run_once)
     logger.debug('eval_interval_secs: %s', eval_interval_secs)
     logger.debug('name: %s', name)
+    logger.debug('keep_n_best_models: %s', keep_n_best_models)
     logger.debug('kwargs: %s', kwargs)
-    if name:
-        real_name = 'eval_' + name
-    else:
-        real_name = 'eval'
-    evaluated_file = os.path.join(estimator.model_dir, real_name, 'evaluated')
+    real_name = 'eval_' + name if name else 'eval'
+    eval_dir = os.path.join(estimator.model_dir, real_name)
+    evaluated_file = os.path.join(eval_dir, 'evaluated')
 
     while True:
-        evaluated_steps = []
+        evaluated_steps = {}
         if os.path.exists(evaluated_file):
-            with open(evaluated_file) as f:
-                evaluated_steps = [line.split()[0] for line in f]
+            evaluated_steps = read_evaluated_file(evaluated_file)
 
         ckpt = tf.train.get_checkpoint_state(estimator.model_dir)
         if (not ckpt) or (not ckpt.model_checkpoint_path):
@@ -113,14 +179,15 @@
             name=name,
         )
 
-        str_evaluations = ', '.join(
-            '%s = %s' % (k, v)
-            for k, v in sorted(six.iteritems(evaluations)))
-        print(str_evaluations)
+        str_evaluations = append_evaluated_file(
+            evaluated_file, evaluations)
+        click.echo(str_evaluations)
         sys.stdout.flush()
-        with open(evaluated_file, 'a') as f:
-            f.write('{} {}\n'.format(evaluations['global_step'],
-                                     str_evaluations))
+
+        # Save the best N models into the eval directory
+        save_best_n_models(estimator.model_dir, eval_dir, evaluated_file,
+                           keep_n_best_models)
+
         if run_once:
             break
         time.sleep(eval_interval_secs)
diff --git a/bob/learn/tensorflow/test/test_estimator_scripts.py b/bob/learn/tensorflow/test/test_estimator_scripts.py
index 7e69f8de..625c6962 100644
--- a/bob/learn/tensorflow/test/test_estimator_scripts.py
+++ b/bob/learn/tensorflow/test/test_estimator_scripts.py
@@ -1,6 +1,7 @@
 from __future__ import print_function
 import os
 import shutil
+from glob import glob
 from tempfile import mkdtemp
 from click.testing import CliRunner
 from bob.io.base.test_utils import datafile
@@ -122,7 +123,7 @@ def _create_checkpoint(tmpdir, model_dir, dummy_tfrecord):
         result.exc_info, result.output, result.exception)
 
 
-def _eval(tmpdir, model_dir, dummy_tfrecord):
+def _eval(tmpdir, model_dir, dummy_tfrecord, extra_args=[]):
     config = CONFIG % {
         'model_dir': model_dir,
         'tfrecord_filenames': dummy_tfrecord
@@ -131,7 +132,7 @@ def _eval(tmpdir, model_dir, dummy_tfrecord):
     with open(config_path, 'w') as f:
         f.write(config)
     runner = CliRunner()
-    result = runner.invoke(eval_script, args=[config_path])
+    result = runner.invoke(eval_script, args=[config_path] + extra_args)
     assert result.exit_code == 0, '%s\n%s\n%s' % (
         result.exc_info, result.output, result.exception)
 
@@ -179,3 +180,35 @@ def test_eval():
             shutil.rmtree(tmpdir)
         except Exception:
             pass
+
+
+def test_eval_keep_n_model():
+    tmpdir = mkdtemp(prefix='bob_')
+    try:
+        model_dir = os.path.join(tmpdir, 'model_dir')
+        eval_dir = os.path.join(model_dir, 'eval')
+
+        print('\nCreating a dummy tfrecord')
+        dummy_tfrecord = _create_tfrecord(tmpdir)
+
+        print('Training a dummy network')
+        _create_checkpoint(tmpdir, model_dir, dummy_tfrecord)
+
+        print('Evaluating a dummy network')
+        _eval(tmpdir, model_dir, dummy_tfrecord, ['-K', '1'])
+
+        evaluated_path = os.path.join(eval_dir, 'evaluated')
+        assert os.path.exists(evaluated_path), evaluated_path
+        with open(evaluated_path) as f:
+            doc = f.read()
+
+        assert '1 ' in doc, doc
+        assert '200 ' in doc, doc
+        assert len(glob('{}/model.ckpt-*'.format(eval_dir))) == 3, \
+            os.listdir(eval_dir)
+
+    finally:
+        try:
+            shutil.rmtree(tmpdir)
+        except Exception:
+            pass
--
GitLab
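
Note on the file format used by the helpers above: append_evaluated_file() writes one line per evaluation to the `evaluated` file, starting with the global step and followed by the metrics returned by the estimator, sorted by name (for example `200 accuracy = 0.75, global_step = 200, loss = 0.4`; the numbers here are illustrative only). read_evaluated_file() parses those lines back into a dict keyed by the global step as a string, and save_best_n_models() ranks the entries to decide which checkpoint files are copied into the eval directory. A minimal sketch of that ranking, using made-up metric values rather than anything from this patch:

# Sketch of the ranking used by save_best_n_models(); the dict below
# mimics what read_evaluated_file() returns, with invented values.
evaluated = {
    '100': {'accuracy': 0.50, 'global_step': 100, 'loss': 0.70},
    '200': {'accuracy': 0.75, 'global_step': 200, 'loss': 0.40},
}

def _key(x):
    # Same key as in the patch: prefer higher accuracy,
    # fall back to lower loss when no accuracy was reported.
    x = x[1]
    ac = x.get('accuracy')
    lo = x.get('loss') or 0
    return ac * -1 if ac is not None else lo

keep_n_best_models = 1
best = dict(sorted(evaluated.items(), key=_key)[:keep_n_best_models])
print(list(best))  # ['200'] -> only the step-200 checkpoint files are kept

With `-K 1`, the new test therefore expects the eval directory to end up with the files of a single checkpoint; the `== 3` assertion assumes the usual three files a TensorFlow Saver writes per checkpoint (the .data-*, .index and .meta files).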