Skip to content
Snippets Groups Projects
Commit 3bab7d65 authored by Saeed SARFJOO
Browse files

add sort_by_accuracy and max_wait_intervals to eval.py

parent 6b0a1c58
Branches
Tags
1 merge request: !61 "add sort_by_accuracy and max_wait_intervals to eval.py"
Pipeline #
...@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) ...@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
def save_n_best_models(train_dir, save_dir, evaluated_file, def save_n_best_models(train_dir, save_dir, evaluated_file,
keep_n_best_models): keep_n_best_models, sort_by_accuracy):
create_directories_safe(save_dir) create_directories_safe(save_dir)
evaluated = read_evaluated_file(evaluated_file) evaluated = read_evaluated_file(evaluated_file)
...@@ -31,7 +31,10 @@ def save_n_best_models(train_dir, save_dir, evaluated_file, ...@@ -31,7 +31,10 @@ def save_n_best_models(train_dir, save_dir, evaluated_file,
x = x[1] x = x[1]
ac = x.get('accuracy') or 0 ac = x.get('accuracy') or 0
lo = x.get('loss') or 0 lo = x.get('loss') or 0
return (lo, ac * -1) if sort_by_accuracy:
return (ac * -1, lo)
else:
return (lo, ac * -1)
best_models = OrderedDict( best_models = OrderedDict(
sorted(evaluated.items(), key=_key)[:keep_n_best_models]) sorted(evaluated.items(), key=_key)[:keep_n_best_models])
...@@ -145,26 +148,59 @@ def append_evaluated_file(path, evaluations): ...@@ -145,26 +148,59 @@ def append_evaluated_file(path, evaluations):
show_default=True, show_default=True,
help='If more than 0, will keep the best N models in the evaluation folder' help='If more than 0, will keep the best N models in the evaluation folder'
) )
@click.option(
'--sort-by-accuracy',
cls=ResourceOption,
default=False,
show_default=True,
help='If given, the N best models will be chosen based on accuracy instead of loss.')
@click.option(
'--max-wait-intervals',
cls=ResourceOption,
type=click.INT,
default=-1,
show_default=True,
help='If given, the maximum number of intervals waiting for new training checkpoint.')
@verbosity_option(cls=ResourceOption) @verbosity_option(cls=ResourceOption)
def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name, def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
keep_n_best_models, **kwargs): keep_n_best_models, sort_by_accuracy, max_wait_intervals, **kwargs):
"""Evaluates networks using Tensorflow estimators.""" """Evaluates networks using Tensorflow estimators."""
log_parameters(logger) if not click.get_current_context(True) is None:
log_parameters(logger)
real_name = 'eval_' + name if name else 'eval' real_name = 'eval_' + name if name else 'eval'
eval_dir = os.path.join(estimator.model_dir, real_name) eval_dir = os.path.join(estimator.model_dir, real_name)
evaluated_file = os.path.join(eval_dir, 'evaluated') evaluated_file = os.path.join(eval_dir, 'evaluated')
wait_interval_count = 0
evaluated_steps_count = 0
while True: while True:
evaluated_steps = {} evaluated_steps = {}
if os.path.exists(evaluated_file): if os.path.exists(evaluated_file):
evaluated_steps = read_evaluated_file(evaluated_file) evaluated_steps = read_evaluated_file(evaluated_file)
if max_wait_intervals > 0:
new_evaluated_count = len(evaluated_steps.keys())
if new_evaluated_count > 0:
if new_evaluated_count == evaluated_steps_count:
wait_interval_count += 1
if wait_interval_count > max_wait_intervals:
break
else:
evaluated_steps_count = new_evaluated_count
wait_interval_count = 0
# Save the best N models into the eval directory # Save the best N models into the eval directory
save_n_best_models(estimator.model_dir, eval_dir, evaluated_file, save_n_best_models(estimator.model_dir, eval_dir, evaluated_file,
keep_n_best_models) keep_n_best_models, sort_by_accuracy)
ckpt = tf.train.get_checkpoint_state(estimator.model_dir) ckpt = tf.train.get_checkpoint_state(estimator.model_dir)
if (not ckpt) or (not ckpt.model_checkpoint_path): if (not ckpt) or (not ckpt.model_checkpoint_path):
if max_wait_intervals > 0:
wait_interval_count+=1
if wait_interval_count > max_wait_intervals:
break
time.sleep(eval_interval_secs) time.sleep(eval_interval_secs)
continue continue
...@@ -194,7 +230,7 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name, ...@@ -194,7 +230,7 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
# Save the best N models into the eval directory # Save the best N models into the eval directory
save_n_best_models(estimator.model_dir, eval_dir, evaluated_file, save_n_best_models(estimator.model_dir, eval_dir, evaluated_file,
keep_n_best_models) keep_n_best_models, sort_by_accuracy)
if run_once: if run_once:
break break
......
...@@ -53,7 +53,8 @@ logger = logging.getLogger(__name__) ...@@ -53,7 +53,8 @@ logger = logging.getLogger(__name__)
@verbosity_option(cls=ResourceOption) @verbosity_option(cls=ResourceOption)
def train(estimator, train_input_fn, hooks, steps, max_steps, **kwargs): def train(estimator, train_input_fn, hooks, steps, max_steps, **kwargs):
"""Trains networks using Tensorflow estimators.""" """Trains networks using Tensorflow estimators."""
log_parameters(logger) if not click.get_current_context(True) is None:
log_parameters(logger)
# Train # Train
logger.info("Training a model in %s", estimator.model_dir) logger.info("Training a model in %s", estimator.model_dir)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment