Skip to content
Snippets Groups Projects
Commit 7e227594 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Mention at which step did the best value happen

parent 65f88ff0
Branches
Tags
1 merge request!38Early stopping hook
Pipeline #
......@@ -141,6 +141,7 @@ class EarlyStopping(tf.train.SessionRunHook):
# Allow instances to be re-used
self.wait = 0
self.best = np.Inf if self.monitor_op == np.less else -np.Inf
self.global_step_of_best = 0
def begin(self):
self.values = []
......@@ -148,26 +149,36 @@ class EarlyStopping(tf.train.SessionRunHook):
self.monitor = _as_graph_element(self.monitor)
else:
self.monitor = _as_graph_element(self.monitor.name)
self.global_step_tensor = tf.train.get_global_step()
def before_run(self, run_context):
return tf.train.SessionRunArgs(self.monitor)
return tf.train.SessionRunArgs([self.monitor, self.global_step_tensor])
def after_run(self, run_context, run_values):
self.values.append(run_values.results)
monitor, global_step = run_values.results
self.values.append(monitor)
# global step does not change during evaluation so keeping one of them
# is enough.
self.global_step_value = global_step
def _should_stop(self):
current = np.mean(self.values)
logger.info('%s is currently at %f and the best value was %f',
self.monitor.name, current, self.best)
logger.info(
'%s is currently at %f (at step of %d) and the best value was %f '
'(at step of %d)', self.monitor.name, current,
self.global_step_value, self.best, self.global_step_of_best)
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
self.global_step_of_best = self.global_step_value
else:
if self.wait >= self.patience:
raise EarlyStopException(
'Early stopping happened with {} at best of '
'{} and current of {}'.format(
self.monitor, self.best, current))
message = 'Early stopping happened with {} at best of ' \
'{} (at step {}) and current of {} (at step {})'.format(
self.monitor.name, self.best, self.global_step_of_best,
current, self.global_step_value)
logger.info(message)
raise EarlyStopException(message)
self.wait += 1
def end(self, session):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment