Commit 44e40482 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Implement model saving in bob tf eval. Fixes #54

parent 73aa5774
......@@ -4,20 +4,86 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import click
import logging
import os
import time
import six
import shutil
import sys
import tensorflow as tf
import time
from glob import glob
from collections import defaultdict
from ..utils.eval import get_global_step
import click
from bob.extension.scripts.click_helper import (
verbosity_option, ConfigCommand, ResourceOption)
from bob.io.base import create_directories_safe
logger = logging.getLogger(__name__)
def save_best_n_models(train_dir, save_dir, evaluated_file,
keep_n_best_models):
create_directories_safe(save_dir)
evaluated = read_evaluated_file(evaluated_file)
def _key(x):
x = x[1]
ac = x.get('accuracy')
lo = x.get('loss') or 0
return ac * -1 if ac is not None else lo
best_models = dict(sorted(
evaluated.items(), key=_key)[:keep_n_best_models])
# delete the old saved models that are not in top N best anymore
saved_models = defaultdict(list)
for path in glob('{}/model.ckpt-*'.format(save_dir)):
global_step = path.split('model.ckpt-')[1].split('.')[0]
saved_models[global_step].append(path)
for global_step, paths in saved_models.items():
if global_step not in best_models:
for path in paths:
logger.info("Deleting `%s'", path)
os.remove(path)
# copy over the best models if not already there
for global_step in best_models:
for path in glob('{}/model.ckpt-{}.*'.format(train_dir, global_step)):
dst = os.path.join(save_dir, os.path.basename(path))
if os.path.isfile(dst):
continue
logger.info("Copying `%s' over to `%s'", path, dst)
shutil.copy(path, dst)
def read_evaluated_file(path):
evaluated = {}
with open(path) as f:
for line in f:
global_step, line = line.split(' ', 1)
temp = {}
for k_v in line.strip().split(', '):
k, v = k_v.split(' = ')
v = float(v)
if 'global_step' in k:
v = int(v)
temp[k] = v
evaluated[global_step] = temp
return evaluated
def append_evaluated_file(path, evaluations):
str_evaluations = ', '.join(
'%s = %s' % (k, v)
for k, v in sorted(six.iteritems(evaluations)))
with open(path, 'a') as f:
f.write('{} {}\n'.format(evaluations['global_step'],
str_evaluations))
return str_evaluations
@click.command(entry_point_group='bob.learn.tensorflow.config',
cls=ConfigCommand)
@click.option('--estimator', '-e', required=True, cls=ResourceOption,
......@@ -28,12 +94,14 @@ logger = logging.getLogger(__name__)
entry_point_group='bob.learn.tensorflow.hook')
@click.option('--run-once', cls=ResourceOption, default=False,
show_default=True)
@click.option('--eval-interval-secs', cls=ResourceOption, type=click.types.INT,
@click.option('--eval-interval-secs', cls=ResourceOption, type=click.INT,
default=60, show_default=True)
@click.option('--name', cls=ResourceOption)
@click.option('--keep-n-best-models', '-K', type=click.INT, cls=ResourceOption,
default=0, show_default=True)
@verbosity_option(cls=ResourceOption)
def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
**kwargs):
keep_n_best_models, **kwargs):
"""Evaluates networks using Tensorflow estimators.
\b
......@@ -76,18 +144,16 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
logger.debug('run_once: %s', run_once)
logger.debug('eval_interval_secs: %s', eval_interval_secs)
logger.debug('name: %s', name)
logger.debug('keep_n_best_models: %s', keep_n_best_models)
logger.debug('kwargs: %s', kwargs)
if name:
real_name = 'eval_' + name
else:
real_name = 'eval'
evaluated_file = os.path.join(estimator.model_dir, real_name, 'evaluated')
real_name = 'eval_' + name if name else 'eval'
eval_dir = os.path.join(estimator.model_dir, real_name)
evaluated_file = os.path.join(eval_dir, 'evaluated')
while True:
evaluated_steps = []
evaluated_steps = {}
if os.path.exists(evaluated_file):
with open(evaluated_file) as f:
evaluated_steps = [line.split()[0] for line in f]
evaluated_steps = read_evaluated_file(evaluated_file)
ckpt = tf.train.get_checkpoint_state(estimator.model_dir)
if (not ckpt) or (not ckpt.model_checkpoint_path):
......@@ -113,14 +179,15 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
name=name,
)
str_evaluations = ', '.join(
'%s = %s' % (k, v)
for k, v in sorted(six.iteritems(evaluations)))
print(str_evaluations)
str_evaluations = append_evaluated_file(
evaluated_file, evaluations)
click.echo(str_evaluations)
sys.stdout.flush()
with open(evaluated_file, 'a') as f:
f.write('{} {}\n'.format(evaluations['global_step'],
str_evaluations))
# Save the best N models into the eval directory
save_best_n_models(estimator.model_dir, eval_dir, evaluated_file,
keep_n_best_models)
if run_once:
break
time.sleep(eval_interval_secs)
from __future__ import print_function
import os
import shutil
from glob import glob
from tempfile import mkdtemp
from click.testing import CliRunner
from bob.io.base.test_utils import datafile
......@@ -122,7 +123,7 @@ def _create_checkpoint(tmpdir, model_dir, dummy_tfrecord):
result.exc_info, result.output, result.exception)
def _eval(tmpdir, model_dir, dummy_tfrecord):
def _eval(tmpdir, model_dir, dummy_tfrecord, extra_args=[]):
config = CONFIG % {
'model_dir': model_dir,
'tfrecord_filenames': dummy_tfrecord
......@@ -131,7 +132,7 @@ def _eval(tmpdir, model_dir, dummy_tfrecord):
with open(config_path, 'w') as f:
f.write(config)
runner = CliRunner()
result = runner.invoke(eval_script, args=[config_path])
result = runner.invoke(eval_script, args=[config_path] + extra_args)
assert result.exit_code == 0, '%s\n%s\n%s' % (
result.exc_info, result.output, result.exception)
......@@ -179,3 +180,35 @@ def test_eval():
shutil.rmtree(tmpdir)
except Exception:
pass
def test_eval_keep_n_model():
tmpdir = mkdtemp(prefix='bob_')
try:
model_dir = os.path.join(tmpdir, 'model_dir')
eval_dir = os.path.join(model_dir, 'eval')
print('\nCreating a dummy tfrecord')
dummy_tfrecord = _create_tfrecord(tmpdir)
print('Training a dummy network')
_create_checkpoint(tmpdir, model_dir, dummy_tfrecord)
print('Evaluating a dummy network')
_eval(tmpdir, model_dir, dummy_tfrecord, ['-K', '1'])
evaluated_path = os.path.join(eval_dir, 'evaluated')
assert os.path.exists(evaluated_path), evaluated_path
with open(evaluated_path) as f:
doc = f.read()
assert '1 ' in doc, doc
assert '200 ' in doc, doc
assert len(glob('{}/model.ckpt-*'.format(eval_dir))) == 3, \
os.listdir(eval_dir)
finally:
try:
shutil.rmtree(tmpdir)
except Exception:
pass
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment