Commit 44e40482 authored by Amir MOHAMMADI

Implement model saving in bob tf eval. Fixes #54

parent 73aa5774
1 merge request: !52 Implement model saving in bob tf eval. Fixes #54
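
The commit adds a --keep-n-best-models / -K option to the bob tf eval command: after every evaluation pass, the checkpoints with the best recorded metrics are copied into the evaluation directory so they survive the trainer's own checkpoint rotation. Below is a minimal sketch of driving the command from Python, mirroring the tests further down; the import path and the config file name are assumptions, and the config file is expected to define the estimator, eval_input_fn and, for a single pass, run_once = True:

    from click.testing import CliRunner
    # assumed import path for the eval command edited in this commit
    from bob.learn.tensorflow.script.eval import eval as eval_script

    runner = CliRunner()
    # keep the two best checkpoints around in <model_dir>/eval
    result = runner.invoke(eval_script, args=['eval_config.py', '-K', '2'])
    assert result.exit_code == 0, result.output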
@@ -4,20 +4,86 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+import click
 import logging
 import os
-import time
 import six
+import shutil
 import sys
 import tensorflow as tf
+import time
+from glob import glob
+from collections import defaultdict
 from ..utils.eval import get_global_step
-import click
 from bob.extension.scripts.click_helper import (
     verbosity_option, ConfigCommand, ResourceOption)
+from bob.io.base import create_directories_safe

 logger = logging.getLogger(__name__)


+def save_best_n_models(train_dir, save_dir, evaluated_file,
+                       keep_n_best_models):
+    create_directories_safe(save_dir)
+    evaluated = read_evaluated_file(evaluated_file)
+
+    def _key(x):
+        x = x[1]
+        ac = x.get('accuracy')
+        lo = x.get('loss') or 0
+        return ac * -1 if ac is not None else lo
+
+    best_models = dict(sorted(
+        evaluated.items(), key=_key)[:keep_n_best_models])
+
+    # delete the old saved models that are not in top N best anymore
+    saved_models = defaultdict(list)
+    for path in glob('{}/model.ckpt-*'.format(save_dir)):
+        global_step = path.split('model.ckpt-')[1].split('.')[0]
+        saved_models[global_step].append(path)
+    for global_step, paths in saved_models.items():
+        if global_step not in best_models:
+            for path in paths:
+                logger.info("Deleting `%s'", path)
+                os.remove(path)
+
+    # copy over the best models if not already there
+    for global_step in best_models:
+        for path in glob('{}/model.ckpt-{}.*'.format(train_dir, global_step)):
+            dst = os.path.join(save_dir, os.path.basename(path))
+            if os.path.isfile(dst):
+                continue
+            logger.info("Copying `%s' over to `%s'", path, dst)
+            shutil.copy(path, dst)
+
+
+def read_evaluated_file(path):
+    evaluated = {}
+    with open(path) as f:
+        for line in f:
+            global_step, line = line.split(' ', 1)
+            temp = {}
+            for k_v in line.strip().split(', '):
+                k, v = k_v.split(' = ')
+                v = float(v)
+                if 'global_step' in k:
+                    v = int(v)
+                temp[k] = v
+            evaluated[global_step] = temp
+    return evaluated
+
+
+def append_evaluated_file(path, evaluations):
+    str_evaluations = ', '.join(
+        '%s = %s' % (k, v)
+        for k, v in sorted(six.iteritems(evaluations)))
+    with open(path, 'a') as f:
+        f.write('{} {}\n'.format(evaluations['global_step'],
+                                 str_evaluations))
+    return str_evaluations
+
+
 @click.command(entry_point_group='bob.learn.tensorflow.config',
                cls=ConfigCommand)
 @click.option('--estimator', '-e', required=True, cls=ResourceOption,
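
The evaluated file written by these helpers is a plain-text log with one line per evaluated checkpoint: the global step, a space, then the metrics as comma-separated key = value pairs sorted by key. read_evaluated_file and append_evaluated_file round-trip that format. A small sketch with made-up metric values; the import path is an assumption:

    from bob.learn.tensorflow.script.eval import (  # assumed module path
        append_evaluated_file, read_evaluated_file)

    # metric values below are made up; keys mirror what estimator.evaluate() returns
    evaluations = {'accuracy': 0.75, 'loss': 0.31, 'global_step': 200}
    append_evaluated_file('/tmp/eval/evaluated', evaluations)
    # the file now ends with: 200 accuracy = 0.75, global_step = 200, loss = 0.31

    read_evaluated_file('/tmp/eval/evaluated')
    # {'200': {'accuracy': 0.75, 'global_step': 200, 'loss': 0.31}}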
@@ -28,12 +94,14 @@ logger = logging.getLogger(__name__)
               entry_point_group='bob.learn.tensorflow.hook')
 @click.option('--run-once', cls=ResourceOption, default=False,
               show_default=True)
-@click.option('--eval-interval-secs', cls=ResourceOption, type=click.types.INT,
+@click.option('--eval-interval-secs', cls=ResourceOption, type=click.INT,
               default=60, show_default=True)
 @click.option('--name', cls=ResourceOption)
+@click.option('--keep-n-best-models', '-K', type=click.INT, cls=ResourceOption,
+              default=0, show_default=True)
 @verbosity_option(cls=ResourceOption)
 def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
-         **kwargs):
+         keep_n_best_models, **kwargs):
     """Evaluates networks using Tensorflow estimators.

     \b
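
The new -K/--keep-n-best-models option controls how many of the evaluated checkpoints save_best_n_models copies into the eval directory. Ranking follows the _key helper shown above: higher accuracy wins when an accuracy metric is present, otherwise lower loss wins. A small, self-contained illustration of that ordering with made-up numbers:

    evaluated = {
        '100': {'accuracy': 0.60, 'global_step': 100, 'loss': 0.90},
        '200': {'accuracy': 0.75, 'global_step': 200, 'loss': 0.31},
        '300': {'accuracy': 0.70, 'global_step': 300, 'loss': 0.40},
    }

    def _key(item):
        metrics = item[1]
        accuracy = metrics.get('accuracy')
        loss = metrics.get('loss') or 0
        # negate accuracy so that ascending sort puts the best model first
        return -accuracy if accuracy is not None else loss

    best = dict(sorted(evaluated.items(), key=_key)[:2])
    # keeps the checkpoints of global steps 200 and 300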
@@ -76,18 +144,16 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
     logger.debug('run_once: %s', run_once)
     logger.debug('eval_interval_secs: %s', eval_interval_secs)
     logger.debug('name: %s', name)
+    logger.debug('keep_n_best_models: %s', keep_n_best_models)
     logger.debug('kwargs: %s', kwargs)

-    if name:
-        real_name = 'eval_' + name
-    else:
-        real_name = 'eval'
-    evaluated_file = os.path.join(estimator.model_dir, real_name, 'evaluated')
+    real_name = 'eval_' + name if name else 'eval'
+    eval_dir = os.path.join(estimator.model_dir, real_name)
+    evaluated_file = os.path.join(eval_dir, 'evaluated')

     while True:
-        evaluated_steps = []
+        evaluated_steps = {}
         if os.path.exists(evaluated_file):
-            with open(evaluated_file) as f:
-                evaluated_steps = [line.split()[0] for line in f]
+            evaluated_steps = read_evaluated_file(evaluated_file)

         ckpt = tf.train.get_checkpoint_state(estimator.model_dir)
         if (not ckpt) or (not ckpt.model_checkpoint_path):
@@ -113,14 +179,15 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
             name=name,
         )

-        str_evaluations = ', '.join(
-            '%s = %s' % (k, v)
-            for k, v in sorted(six.iteritems(evaluations)))
-        print(str_evaluations)
+        str_evaluations = append_evaluated_file(
+            evaluated_file, evaluations)
+        click.echo(str_evaluations)
         sys.stdout.flush()
-        with open(evaluated_file, 'a') as f:
-            f.write('{} {}\n'.format(evaluations['global_step'],
-                                     str_evaluations))
+
+        # Save the best N models into the eval directory
+        save_best_n_models(estimator.model_dir, eval_dir, evaluated_file,
+                           keep_n_best_models)

         if run_once:
             break
         time.sleep(eval_interval_secs)
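
After every pass, save_best_n_models copies the best N checkpoints from model_dir into the eval directory and removes previously copied checkpoints that have dropped out of the top N. A TensorFlow 1.x Saver checkpoint at a given global step normally consists of three files (.index, .meta and a .data shard), which is why keeping a single best model leaves three model.ckpt-* files behind, as the test below asserts. A hedged sketch of inspecting the result, with hypothetical paths:

    import os
    from glob import glob

    eval_dir = os.path.join('model_dir', 'eval')  # hypothetical location
    print(open(os.path.join(eval_dir, 'evaluated')).read())
    print(sorted(glob(os.path.join(eval_dir, 'model.ckpt-*'))))
    # with -K 1 one would expect something like:
    # ['model_dir/eval/model.ckpt-200.data-00000-of-00001',
    #  'model_dir/eval/model.ckpt-200.index',
    #  'model_dir/eval/model.ckpt-200.meta']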
 from __future__ import print_function
 import os
 import shutil
+from glob import glob
 from tempfile import mkdtemp
 from click.testing import CliRunner
 from bob.io.base.test_utils import datafile
@@ -122,7 +123,7 @@ def _create_checkpoint(tmpdir, model_dir, dummy_tfrecord):
         result.exc_info, result.output, result.exception)


-def _eval(tmpdir, model_dir, dummy_tfrecord):
+def _eval(tmpdir, model_dir, dummy_tfrecord, extra_args=[]):
     config = CONFIG % {
         'model_dir': model_dir,
         'tfrecord_filenames': dummy_tfrecord
@@ -131,7 +132,7 @@ def _eval(tmpdir, model_dir, dummy_tfrecord):
     with open(config_path, 'w') as f:
         f.write(config)
     runner = CliRunner()
-    result = runner.invoke(eval_script, args=[config_path])
+    result = runner.invoke(eval_script, args=[config_path] + extra_args)
     assert result.exit_code == 0, '%s\n%s\n%s' % (
         result.exc_info, result.output, result.exception)
@@ -179,3 +180,35 @@ def test_eval():
             shutil.rmtree(tmpdir)
         except Exception:
             pass
+
+
+def test_eval_keep_n_model():
+    tmpdir = mkdtemp(prefix='bob_')
+    try:
+        model_dir = os.path.join(tmpdir, 'model_dir')
+        eval_dir = os.path.join(model_dir, 'eval')
+
+        print('\nCreating a dummy tfrecord')
+        dummy_tfrecord = _create_tfrecord(tmpdir)
+
+        print('Training a dummy network')
+        _create_checkpoint(tmpdir, model_dir, dummy_tfrecord)
+
+        print('Evaluating a dummy network')
+        _eval(tmpdir, model_dir, dummy_tfrecord, ['-K', '1'])
+
+        evaluated_path = os.path.join(eval_dir, 'evaluated')
+        assert os.path.exists(evaluated_path), evaluated_path
+        with open(evaluated_path) as f:
+            doc = f.read()
+
+        assert '1 ' in doc, doc
+        assert '200 ' in doc, doc
+
+        assert len(glob('{}/model.ckpt-*'.format(eval_dir))) == 3, \
+            os.listdir(eval_dir)
+
+    finally:
+        try:
+            shutil.rmtree(tmpdir)
+        except Exception:
+            pass