Skip to content
Snippets Groups Projects
Commit 6cb3a7f5 authored by Amir MOHAMMADI
Browse files

Improvements to eval script

parent c7a4d9f7
Branches
Tags
1 merge request: !73 Improvements to eval script
Pipeline #27486 failed
......@@ -37,21 +37,25 @@ def copy_one_step(train_dir, global_step, save_dir):
def save_n_best_models(train_dir, save_dir, evaluated_file,
keep_n_best_models, sort_by):
keep_n_best_models, sort_by, exceptions=tuple()):
logger.debug(
"save_n_best_models was called with %s, %s, %s, %s, %s, %s",
train_dir, save_dir, evaluated_file, keep_n_best_models, sort_by,
exceptions)
create_directories_safe(save_dir)
evaluated = read_evaluated_file(evaluated_file)
def _key(x):
x = x[1]
ac = x.get('accuracy') or 0
lo = x.get('loss') or 0
if sort_by == 'accuracy':
return (ac * -1, lo)
x = x[1][sort_by]
if 'loss' in sort_by:
return x
else:
return (lo, ac * -1)
return -x
best_models = OrderedDict(
sorted(evaluated.items(), key=_key)[:keep_n_best_models])
logger.info("Best models: %s", best_models)
# delete the old saved models that are not in top N best anymore
saved_models = defaultdict(list)
......@@ -60,7 +64,7 @@ def save_n_best_models(train_dir, save_dir, evaluated_file,
saved_models[global_step].append(path)
for global_step, paths in saved_models.items():
if global_step not in best_models:
if global_step not in best_models and global_step not in exceptions:
for path in paths:
logger.info("Deleting `%s'", path)
os.remove(path)
......@@ -98,7 +102,10 @@ def read_evaluated_file(path):
temp = {}
for k_v in line.strip().split(', '):
k, v = k_v.split(' = ')
v = float(v)
try:
v = float(v)
except ValueError: # not all values could be floats
pass
if 'global_step' in k:
v = int(v)
temp[k] = v
......@@ -189,13 +196,29 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
real_name = 'eval_' + name if name else 'eval'
eval_dir = os.path.join(estimator.model_dir, real_name)
os.makedirs(eval_dir, exist_ok=True)
evaluated_file = os.path.join(eval_dir, 'evaluated')
wait_interval_count = 0
evaluated_steps_count = 0
while True:
evaluated_steps = {}
ckpt = tf.train.get_checkpoint_state(estimator.model_dir)
if os.path.exists(evaluated_file):
evaluated_steps = read_evaluated_file(evaluated_file)
# create exceptions so we don't delete them
exceptions = []
if ckpt and ckpt.model_checkpoint_path:
for checkpoint_path in ckpt.all_model_checkpoint_paths:
try:
global_step = str(get_global_step(checkpoint_path))
except Exception:
logger.warning("Failed to find global_step.", exc_info=True)
continue
if global_step not in evaluated_steps:
exceptions.append(global_step)
if max_wait_intervals > 0:
new_evaluated_count = len(evaluated_steps.keys())
if new_evaluated_count > 0:
......@@ -209,9 +232,8 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
# Save the best N models into the eval directory
save_n_best_models(estimator.model_dir, eval_dir, evaluated_file,
keep_n_best_models, sort_by)
keep_n_best_models, sort_by, exceptions)
ckpt = tf.train.get_checkpoint_state(estimator.model_dir)
if (not ckpt) or (not ckpt.model_checkpoint_path):
if max_wait_intervals > 0:
wait_interval_count += 1
......@@ -224,8 +246,9 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
try:
global_step = str(get_global_step(checkpoint_path))
except Exception:
print('Failed to find global_step for checkpoint_path {}, '
'skipping ...'.format(checkpoint_path))
logger.warning(
'Failed to find global_step for checkpoint_path {}, '
'skipping ...'.format(checkpoint_path), exc_info=True)
continue
if global_step in evaluated_steps and not force_re_run:
continue
......@@ -234,6 +257,10 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
# disappear after evaluation.
copy_one_step(estimator.model_dir, global_step, eval_dir)
# evaluate based on the just copied checkpoint_path
checkpoint_path = checkpoint_path.replace(estimator.model_dir, eval_dir)
logger.debug("Evaluating the model from %s", checkpoint_path)
# Evaluate
try:
evaluations = estimator.evaluate(
......@@ -243,9 +270,9 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
checkpoint_path=checkpoint_path,
name=name,
)
# if the model gets deleted before we can evaluate it
except (tf.errors.NotFoundError, ValueError):
break
except Exception:
logger.info("Something went wrong in evaluation.")
raise
str_evaluations = append_evaluated_file(evaluated_file,
evaluations)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment