Skip to content
Snippets Groups Projects
Commit e59587c9 authored by Saeed SARFJOO's avatar Saeed SARFJOO
Browse files

set sort-by by string

parent cde558b2
Branches
No related tags found
1 merge request!61add sort_by_accuracy and max_wait_intervals to eval.py
Pipeline #
......@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
def save_n_best_models(train_dir, save_dir, evaluated_file,
keep_n_best_models, sort_by_accuracy):
keep_n_best_models, sort_by):
create_directories_safe(save_dir)
evaluated = read_evaluated_file(evaluated_file)
......@@ -31,7 +31,7 @@ def save_n_best_models(train_dir, save_dir, evaluated_file,
x = x[1]
ac = x.get('accuracy') or 0
lo = x.get('loss') or 0
if sort_by_accuracy:
if sort_by == 'accuracy':
return (ac * -1, lo)
else:
return (lo, ac * -1)
......@@ -149,11 +149,11 @@ def append_evaluated_file(path, evaluations):
help='If more than 0, will keep the best N models in the evaluation folder'
)
@click.option(
'--sort-by-accuracy',
'--sort-by',
cls=ResourceOption,
default=False,
default="loss",
show_default=True,
help='If given, the N best models will be chosen based on accuracy instead of loss.')
help='The metric for sorting the N best models.')
@click.option(
'--max-wait-intervals',
cls=ResourceOption,
......@@ -163,7 +163,7 @@ def append_evaluated_file(path, evaluations):
help='If given, the maximum number of intervals waiting for new training checkpoint.')
@verbosity_option(cls=ResourceOption)
def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
keep_n_best_models, sort_by_accuracy, max_wait_intervals, **kwargs):
keep_n_best_models, sort_by, max_wait_intervals, **kwargs):
"""Evaluates networks using Tensorflow estimators."""
log_parameters(logger)
......@@ -192,7 +192,7 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
# Save the best N models into the eval directory
save_n_best_models(estimator.model_dir, eval_dir, evaluated_file,
keep_n_best_models, sort_by_accuracy)
keep_n_best_models, sort_by)
ckpt = tf.train.get_checkpoint_state(estimator.model_dir)
if (not ckpt) or (not ckpt.model_checkpoint_path):
......@@ -229,7 +229,7 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
# Save the best N models into the eval directory
save_n_best_models(estimator.model_dir, eval_dir, evaluated_file,
keep_n_best_models, sort_by_accuracy)
keep_n_best_models, sort_by)
if run_once:
break
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment