Commit 20d5c20d authored by Amir MOHAMMADI

Merge branch 'early-stop' into 'master'

Early stopping hook

See merge request !38
parents 4add4621 7e227594
Pipeline #14645 failed with stages in 17 minutes and 15 seconds
# mnist_estimator.py: configuration defining a canned linear classifier
# over the 784-dimensional MNIST vectors.
import tensorflow as tf

data = tf.feature_column.numeric_column('data', shape=[784])
estimator = tf.estimator.LinearClassifier(
    feature_columns=[data], n_classes=10)
# mnist_input_fn.py: configuration building the train and eval input
# functions from the MNIST database.
from bob.db.mnist import Database
import tensorflow as tf

database = Database()


def input_fn(mode):
    if mode == tf.estimator.ModeKeys.TRAIN:
        groups = 'train'
        num_epochs = None
        shuffle = True
    else:
        groups = 'test'
        num_epochs = 1
        shuffle = True
    data, labels = database.data(groups=groups)
    return tf.estimator.inputs.numpy_input_fn(
        x={'data': data.astype('float32'), 'key': labels.astype('float32')},
        y=labels.astype('int32'),
        batch_size=128,
        num_epochs=num_epochs,
        shuffle=shuffle)


train_input_fn = input_fn(tf.estimator.ModeKeys.TRAIN)
eval_input_fn = input_fn(tf.estimator.ModeKeys.EVAL)
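
# A minimal, hypothetical usage sketch (not from the original config files):
# numpy_input_fn returns a callable that, when invoked, builds the batched
# input tensors the estimator consumes. Assumes the bob.db.mnist data files
# are available locally.
if __name__ == '__main__':
    features, labels = train_input_fn()
    print(features['data'])  # Tensor of shape (128, 784), dtype float32
    print(labels)            # Tensor of shape (128,), dtype int32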
# Tests for the EarlyStopping hook: both are expected to raise
# EarlyStopException once the monitored accuracy stops improving.
from bob.extension.config import load as read_config_files
from bob.io.base.test_utils import datafile
from bob.learn.tensorflow.estimators import Logits
from bob.learn.tensorflow.loss.BaseLoss import mean_cross_entropy_loss
from bob.learn.tensorflow.utils.hooks import EarlyStopping, EarlyStopException
import nose
import tensorflow as tf


@nose.tools.raises(EarlyStopException)
def test_early_stopping_linear_classifier():
    config = read_config_files([
        datafile('mnist_input_fn.py', __name__),
        datafile('mnist_estimator.py', __name__),
    ])
    estimator = config.estimator
    train_input_fn = config.train_input_fn
    eval_input_fn = config.eval_input_fn

    hooks = [
        EarlyStopping('linear/head/metrics/accuracy/value',
                      min_delta=0.001, patience=1),
    ]

    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
    eval_spec = tf.estimator.EvalSpec(
        input_fn=eval_input_fn, hooks=hooks, throttle_secs=2, steps=10)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)


@nose.tools.raises(EarlyStopException)
def test_early_stopping_logit_trainer():
    config = read_config_files([
        datafile('mnist_input_fn.py', __name__),
    ])
    train_input_fn = config.train_input_fn
    eval_input_fn = config.eval_input_fn

    hooks = [
        EarlyStopping('accuracy/value', min_delta=0.001, patience=1),
    ]

    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
    eval_spec = tf.estimator.EvalSpec(
        input_fn=eval_input_fn, hooks=hooks, throttle_secs=2, steps=10)

    def architecture(data, mode, **kwargs):
        return data, dict()

    optimizer = tf.train.GradientDescentOptimizer(learning_rate=1e-1)
    loss_op = mean_cross_entropy_loss
    estimator = Logits(architecture, optimizer, loss_op,
                       n_classes=10, model_dir=None)

    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
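
# Note the different monitor names above: the canned LinearClassifier scopes
# its metrics under 'linear/head/metrics/', while the custom Logits estimator
# exposes 'accuracy/value' directly. A hypothetical helper sketch (not part
# of the tests) for listing candidate monitor tensors in the current graph:
def list_candidate_monitors(graph=None):
    graph = graph or tf.get_default_graph()
    # tf.metrics value tensors are conventionally named '<metric>/value'.
    return [op.name for op in graph.get_operations()
            if op.name.endswith('/value')]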
# bob/learn/tensorflow/utils/hooks.py
from datetime import datetime
import logging
import time

import numpy as np
import six
import tensorflow as tf
from tensorflow.python.training.basic_session_run_hooks import _as_graph_element

logger = logging.getLogger(__name__)
class LoggerHook(tf.train.SessionRunHook):

@@ -33,7 +39,8 @@ class LoggerHook(tf.train.SessionRunHook):
                          'sec/batch)')
            print(format_str % (datetime.now(), self._step, loss_value,
                                examples_per_sec, sec_per_batch))


class LoggerHookEstimator(tf.train.SessionRunHook):
    """Logs loss and runtime."""

@@ -48,7 +55,8 @@ class LoggerHookEstimator(tf.train.SessionRunHook):
    def before_run(self, run_context):
        self._step += 1
        # Asks for loss value.
        return tf.train.SessionRunArgs(self.estimator.loss)

    def after_run(self, run_context, run_values):
        if self._step % self.log_frequency == 0:

@@ -63,4 +71,115 @@ class LoggerHookEstimator(tf.train.SessionRunHook):
            format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (datetime.now(), self._step, loss_value,
                                examples_per_sec, sec_per_batch))


class EarlyStopException(Exception):
    pass


class EarlyStopping(tf.train.SessionRunHook):
"""Stop training when a monitored quantity has stopped improving.
Based on Keras's EarlyStopping callback:
https://keras.io/callbacks/#earlystopping
The original implementation worked for epochs. Currently there is no way
to know the epoch count in estimator training. Hence, the criteria is done
using steps instead of epochs.
Parameters
----------
monitor
quantity to be monitored.
min_delta
minimum change in the monitored quantity to qualify as an improvement,
i.e. an absolute change of less than min_delta, will count as no
improvement.
patience
number of steps with no improvement after which training will be
stopped. Please use large patience values since this hook is
implemented using steps instead of epochs compared to the equivalent
one in Keras.
mode
one of {auto, min, max}. In `min` mode, training will stop when the
quantity monitored has stopped decreasing; in `max` mode it will stop
when the quantity monitored has stopped increasing; in `auto` mode, the
direction is automatically inferred from the name of the monitored
quantity.
"""
    def __init__(self,
                 monitor='accuracy/value',
                 min_delta=0,
                 patience=0,
                 mode='auto'):
        super(EarlyStopping, self).__init__()
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta

        if mode not in ['auto', 'min', 'max']:
            logger.warning('EarlyStopping mode %s is unknown, '
                           'falling back to auto mode.', mode)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            # Auto mode: metrics with 'acc' in their name are assumed to
            # improve when increasing; everything else when decreasing.
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        # Flip the sign of min_delta so that `current - min_delta` always
        # shifts the comparison against the direction of improvement.
        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1

        # Allow instances to be re-used.
        self.wait = 0
        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
        self.global_step_of_best = 0
    def begin(self):
        self.values = []
        if isinstance(self.monitor, six.string_types):
            self.monitor = _as_graph_element(self.monitor)
        else:
            self.monitor = _as_graph_element(self.monitor.name)
        self.global_step_tensor = tf.train.get_global_step()

    def before_run(self, run_context):
        return tf.train.SessionRunArgs([self.monitor, self.global_step_tensor])

    def after_run(self, run_context, run_values):
        monitor, global_step = run_values.results
        self.values.append(monitor)
        # The global step does not change during evaluation, so keeping
        # one of its values is enough.
        self.global_step_value = global_step

    def _should_stop(self):
        current = np.mean(self.values)
        logger.info(
            '%s is currently at %f (at step %d) and the best value was %f '
            '(at step %d)', self.monitor.name, current,
            self.global_step_value, self.best, self.global_step_of_best)
        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
            self.global_step_of_best = self.global_step_value
        else:
            if self.wait >= self.patience:
                message = 'Early stopping happened with {} at best of ' \
                          '{} (at step {}) and current of {} (at step {})'.format(
                              self.monitor.name, self.best,
                              self.global_step_of_best, current,
                              self.global_step_value)
                logger.info(message)
                raise EarlyStopException(message)
            self.wait += 1

    def end(self, session):
        self._should_stop()
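
# A self-contained numeric sketch (not part of the hook) of the criterion in
# _should_stop() above: in 'max' mode with min_delta=0.001 and patience=1,
# stopping triggers after two consecutive evaluations that fail to beat the
# best value by more than min_delta.
def _demo_stopping_criterion():
    best, wait, patience, min_delta = -np.inf, 0, 1, 0.001
    for step, current in enumerate([0.90, 0.92, 0.9205, 0.9203]):
        if np.greater(current - min_delta, best):  # improved enough
            best, wait = current, 0
        elif wait >= patience:
            return step  # the hook would raise EarlyStopException here (3)
        else:
            wait += 1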