Skip to content
Snippets Groups Projects
Commit 1e552b18 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Create a checkpoint file for eval

parent 29236312
No related branches found
No related tags found
1 merge request!52Implement model saving in bob tf eval. Fixes #54
......@@ -13,7 +13,7 @@ import sys
import tensorflow as tf
import time
from glob import glob
from collections import defaultdict
from collections import defaultdict, OrderedDict
from ..utils.eval import get_global_step
from bob.extension.scripts.click_helper import (
verbosity_option, ConfigCommand, ResourceOption)
......@@ -33,7 +33,7 @@ def save_n_best_models(train_dir, save_dir, evaluated_file,
lo = x.get('loss') or 0
return ac * -1 if ac is not None else lo
best_models = dict(sorted(
best_models = OrderedDict(sorted(
evaluated.items(), key=_key)[:keep_n_best_models])
# delete the old saved models that are not in top N best anymore
......@@ -57,6 +57,21 @@ def save_n_best_models(train_dir, save_dir, evaluated_file,
logger.info("Copying `%s' over to `%s'", path, dst)
shutil.copy(path, dst)
# create a checkpoint file indicating to the best existing model:
# 1. filter non-existing models first
def _filter(x):
return len(glob('{}/model.ckpt-{}.*'.format(save_dir, x[0]))) > 0
best_models = OrderedDict(filter(_filter, best_models.items()))
# 2. create the checkpoint file
with open(os.path.join(save_dir, 'checkpoint'), 'wt') as f:
for i, global_step in enumerate(best_models):
if i == 0:
f.write('model_checkpoint_path: "model.ckpt-{}"\n'.format(
global_step))
f.write('all_model_checkpoint_paths: "model.ckpt-{}"\n'.format(
global_step))
def read_evaluated_file(path):
evaluated = {}
......@@ -155,6 +170,10 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
if os.path.exists(evaluated_file):
evaluated_steps = read_evaluated_file(evaluated_file)
# Save the best N models into the eval directory
save_n_best_models(estimator.model_dir, eval_dir, evaluated_file,
keep_n_best_models)
ckpt = tf.train.get_checkpoint_state(estimator.model_dir)
if (not ckpt) or (not ckpt.model_checkpoint_path):
time.sleep(eval_interval_secs)
......
......@@ -74,8 +74,8 @@ def save_predictions(pool, output_dir, key, pred_buffer):
entry_point_group='bob.learn.tensorflow.hook')
@click.option('--predict-keys', '-k', multiple=True, default=None,
cls=ResourceOption)
@click.option('--checkpoint-path', cls=ResourceOption)
@click.option('--multiple-samples', is_flag=True, cls=ResourceOption)
@click.option('--checkpoint-path', '-c', cls=ResourceOption)
@click.option('--multiple-samples', '-m', is_flag=True, cls=ResourceOption)
@click.option('--array', '-t', type=click.INT, default=1, cls=ResourceOption)
@click.option('--force', '-f', is_flag=True, cls=ResourceOption)
@verbosity_option(cls=ResourceOption)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment