Commit 34978827 authored by Amir MOHAMMADI

Merge branch 'eval' into 'master'

Improvements to eval script

See merge request !73
parents c7a4d9f7 8ee9b4ae
@@ -22,7 +22,7 @@ from bob.io.base import create_directories_safe
 logger = logging.getLogger(__name__)
 
 
-def copy_one_step(train_dir, global_step, save_dir):
+def copy_one_step(train_dir, global_step, save_dir, fail_on_error=False):
     for path in glob('{}/model.ckpt-{}.*'.format(train_dir, global_step)):
         dst = os.path.join(save_dir, os.path.basename(path))
         if os.path.isfile(dst):
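For context, a minimal sketch of the two calling styles the new keyword enables (paths are hypothetical, and the raise itself is added in the next hunk):

# Default: a failed copy is logged as a warning and the loop moves on.
copy_one_step('/tmp/train', 1000, '/tmp/eval')

# Strict: a failed copy re-raises, so the caller can react, e.g. skip
# a checkpoint whose files vanished mid-copy.
try:
    copy_one_step('/tmp/train', 1000, '/tmp/eval', fail_on_error=True)
except Exception:
    pass  # caller decides what to do with the failed checkpoint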
@@ -34,24 +34,30 @@ def copy_one_step(train_dir, global_step, save_dir):
             logger.warning(
                 "Failed to copy `%s' over to `%s'", path, dst,
                 exc_info=True)
+            if fail_on_error:
+                raise
 
 
 def save_n_best_models(train_dir, save_dir, evaluated_file,
-                       keep_n_best_models, sort_by):
+                       keep_n_best_models, sort_by, exceptions=tuple()):
     logger.debug(
-        "save_n_best_models was called with %s, %s, %s, %s, %s",
-        train_dir, save_dir, evaluated_file, keep_n_best_models, sort_by)
+        "save_n_best_models was called with %s, %s, %s, %s, %s, %s",
+        train_dir, save_dir, evaluated_file, keep_n_best_models, sort_by,
+        exceptions)
     create_directories_safe(save_dir)
     evaluated = read_evaluated_file(evaluated_file)
 
     def _key(x):
-        x = x[1][sort_by]
-        if 'loss' in sort_by:
-            return x
+        x = x[1]
+        ac = x.get('accuracy') or 0
+        lo = x.get('loss') or 0
+        if sort_by == 'accuracy':
+            return (ac * -1, lo)
         else:
-            return -x
+            return (lo, ac * -1)
 
     best_models = OrderedDict(
         sorted(evaluated.items(), key=_key)[:keep_n_best_models])
     logger.info("Best models: %s", best_models)
 
     # delete the old saved models that are not in top N best anymore
     saved_models = defaultdict(list)
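For illustration only (numbers made up): with sort_by='accuracy' the new key sorts highest accuracy first and breaks ties on lowest loss; for any other sort_by, lowest loss wins with accuracy as the tie-breaker. A self-contained sketch:

# Hypothetical dict in the shape produced by read_evaluated_file:
evaluated = {
    '1000': {'accuracy': 0.91, 'loss': 0.40},
    '2000': {'accuracy': 0.95, 'loss': 0.30},
    '3000': {'accuracy': 0.95, 'loss': 0.25},
}
sort_by = 'accuracy'

def _key(x):
    x = x[1]
    ac = x.get('accuracy') or 0  # missing metrics count as 0
    lo = x.get('loss') or 0
    if sort_by == 'accuracy':
        return (ac * -1, lo)     # highest accuracy first, then lowest loss
    else:
        return (lo, ac * -1)     # lowest loss first, then highest accuracy

print([k for k, _ in sorted(evaluated.items(), key=_key)])
# ['3000', '2000', '1000'] -- step 3000 wins the accuracy tie on loss

Unlike the old key, which indexed x[1][sort_by] directly, this version tolerates entries where a metric is absent.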
@@ -60,7 +66,7 @@ def save_n_best_models(train_dir, save_dir, evaluated_file,
         saved_models[global_step].append(path)
 
     for global_step, paths in saved_models.items():
-        if global_step not in best_models:
+        if global_step not in best_models and global_step not in exceptions:
             for path in paths:
                 logger.info("Deleting `%s'", path)
                 os.remove(path)
@@ -98,7 +104,10 @@ def read_evaluated_file(path):
             temp = {}
             for k_v in line.strip().split(', '):
                 k, v = k_v.split(' = ')
-                v = float(v)
+                try:
+                    v = float(v)
+                except ValueError:  # not all values could be floats
+                    pass
                 if 'global_step' in k:
                     v = int(v)
                 temp[k] = v
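For illustration, here is how one line of the evaluated file parses under the new tolerant conversion (the sample line and its 'path' key are made up, not taken from the commit):

line = "global_step = 1000, accuracy = 0.95, loss = 0.25, path = model.ckpt-1000"
temp = {}
for k_v in line.strip().split(', '):
    k, v = k_v.split(' = ')
    try:
        v = float(v)
    except ValueError:  # e.g. 'model.ckpt-1000' stays a string
        pass
    if 'global_step' in k:
        v = int(v)
    temp[k] = v
# temp == {'global_step': 1000, 'accuracy': 0.95,
#          'loss': 0.25, 'path': 'model.ckpt-1000'}

Previously the unconditional float(v) would have raised ValueError on any non-numeric value.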
@@ -189,13 +198,29 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
     real_name = 'eval_' + name if name else 'eval'
     eval_dir = os.path.join(estimator.model_dir, real_name)
     os.makedirs(eval_dir, exist_ok=True)
     evaluated_file = os.path.join(eval_dir, 'evaluated')
     wait_interval_count = 0
     evaluated_steps_count = 0
     while True:
         evaluated_steps = {}
+        ckpt = tf.train.get_checkpoint_state(estimator.model_dir)
         if os.path.exists(evaluated_file):
             evaluated_steps = read_evaluated_file(evaluated_file)
 
+        # create exceptions so we don't delete them
+        exceptions = []
+        if ckpt and ckpt.model_checkpoint_path:
+            for checkpoint_path in ckpt.all_model_checkpoint_paths:
+                try:
+                    global_step = str(get_global_step(checkpoint_path))
+                except Exception:
+                    logger.warning("Failed to find global_step.",
+                                   exc_info=True)
+                    continue
+                if global_step not in evaluated_steps:
+                    exceptions.append(global_step)
+
         if max_wait_intervals > 0:
             new_evaluated_count = len(evaluated_steps.keys())
             if new_evaluated_count > 0:
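In effect (illustration only): any checkpoint that TensorFlow still lists but that has no entry in the evaluated file is treated as pending and shielded from deletion by save_n_best_models. A sketch with made-up paths, where rsplit stands in for get_global_step:

all_model_checkpoint_paths = ['/tmp/train/model.ckpt-1000',
                              '/tmp/train/model.ckpt-2000']
evaluated_steps = {'1000': {'loss': 0.3}}  # step 2000 not yet evaluated

exceptions = []
for path in all_model_checkpoint_paths:
    step = path.rsplit('-', 1)[-1]         # stand-in for get_global_step()
    if step not in evaluated_steps:
        exceptions.append(step)
# exceptions == ['2000'] -> the cleanup will not delete step 2000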
@@ -209,9 +234,8 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
 
         # Save the best N models into the eval directory
         save_n_best_models(estimator.model_dir, eval_dir, evaluated_file,
-                           keep_n_best_models, sort_by)
+                           keep_n_best_models, sort_by, exceptions)
 
-        ckpt = tf.train.get_checkpoint_state(estimator.model_dir)
         if (not ckpt) or (not ckpt.model_checkpoint_path):
             if max_wait_intervals > 0:
                 wait_interval_count += 1
@@ -224,15 +248,24 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
             try:
                 global_step = str(get_global_step(checkpoint_path))
             except Exception:
-                print('Failed to find global_step for checkpoint_path {}, '
-                      'skipping ...'.format(checkpoint_path))
+                logger.warning(
+                    'Failed to find global_step for checkpoint_path {}, '
+                    'skipping ...'.format(checkpoint_path), exc_info=True)
                 continue
             if global_step in evaluated_steps and not force_re_run:
                 continue
 
             # copy over the checkpoint before evaluating since it might
             # disappear after evaluation.
-            copy_one_step(estimator.model_dir, global_step, eval_dir)
+            try:
+                copy_one_step(estimator.model_dir, global_step, eval_dir,
+                              fail_on_error=True)
+            except Exception:
+                # skip testing this checkpoint
+                continue
 
             # evaluate based on the just copied checkpoint_path
             checkpoint_path = checkpoint_path.replace(
                 estimator.model_dir, eval_dir)
+            logger.debug("Evaluating the model from %s", checkpoint_path)
 
             # Evaluate
             try:
@@ -243,9 +276,9 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
                     checkpoint_path=checkpoint_path,
                     name=name,
                 )
-            except Exception:
-                logger.info("Something went wrong in evaluation.")
-                raise
+            # if the model gets deleted before we can evaluate it
+            except (tf.errors.NotFoundError, ValueError):
+                break
 
             str_evaluations = append_evaluated_file(evaluated_file,
                                                     evaluations)
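In plain terms, a sketch of the new failure semantics (evaluate_or_stop is a hypothetical helper, not part of the commit): a checkpoint that disappears between copying and evaluating now ends evaluation quietly instead of crashing the loop.

import tensorflow as tf

def evaluate_or_stop(estimator, input_fn, checkpoint_path):
    """Return evaluation results, or None if the checkpoint vanished."""
    try:
        return estimator.evaluate(input_fn=input_fn,
                                  checkpoint_path=checkpoint_path)
    except (tf.errors.NotFoundError, ValueError):
        # the files were garbage-collected after we listed them
        return None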