Skip to content
Snippets Groups Projects
Commit 65f88ff0 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

fix earlystopping. Implement tests

parent ac0057a3
No related branches found
No related tags found
1 merge request!38Early stopping hook
Pipeline #
import tensorflow as tf
data = tf.feature_column.numeric_column('data', shape=[784])
estimator = tf.estimator.LinearClassifier(
feature_columns=[data], n_classes=10)
from bob.db.mnist import Database
import tensorflow as tf
database = Database()
def input_fn(mode):
if mode == tf.estimator.ModeKeys.TRAIN:
groups = 'train'
num_epochs = None
shuffle = True
else:
groups = 'test'
num_epochs = 1
shuffle = True
data, labels = database.data(groups=groups)
return tf.estimator.inputs.numpy_input_fn(
x={"data": data.astype('float32'), 'key': labels.astype('float32')},
y=labels.astype('int32'),
batch_size=128,
num_epochs=num_epochs,
shuffle=shuffle)
train_input_fn = input_fn(tf.estimator.ModeKeys.TRAIN)
eval_input_fn = input_fn(tf.estimator.ModeKeys.EVAL)
from bob.extension.config import load as read_config_files
from bob.io.base.test_utils import datafile
from bob.learn.tensorflow.estimators import Logits
from bob.learn.tensorflow.loss.BaseLoss import mean_cross_entropy_loss
from bob.learn.tensorflow.utils.hooks import EarlyStopping, EarlyStopException
import nose
import tensorflow as tf
@nose.tools.raises(EarlyStopException)
def test_early_stopping_linear_classifier():
config = read_config_files([
datafile('mnist_input_fn.py', __name__),
datafile('mnist_estimator.py', __name__),
])
estimator = config.estimator
train_input_fn = config.train_input_fn
eval_input_fn = config.eval_input_fn
hooks = [
EarlyStopping('linear/head/metrics/accuracy/value',
min_delta=0.001, patience=1),
]
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
eval_spec = tf.estimator.EvalSpec(
input_fn=eval_input_fn, hooks=hooks, throttle_secs=2, steps=10)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
@nose.tools.raises(EarlyStopException)
def test_early_stopping_logit_trainer():
config = read_config_files([
datafile('mnist_input_fn.py', __name__),
])
train_input_fn = config.train_input_fn
eval_input_fn = config.eval_input_fn
hooks = [
EarlyStopping('accuracy/value', min_delta=0.001, patience=1),
]
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
eval_spec = tf.estimator.EvalSpec(
input_fn=eval_input_fn, hooks=hooks, throttle_secs=2, steps=10)
def architecture(data, mode, **kwargs):
return data, dict()
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1e-1)
loss_op = mean_cross_entropy_loss
estimator = Logits(architecture, optimizer, loss_op,
n_classes=10, model_dir=None)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
import tensorflow as tf
import time
from datetime import datetime from datetime import datetime
from tensorflow.python.training.basic_session_run_hooks import _as_graph_element
import logging import logging
import numpy as np import numpy as np
from tensorflow.python.training.basic_session_run_hooks import _as_graph_element import six
import tensorflow as tf
import time
logger = logging.getLogger(__name__)
class LoggerHook(tf.train.SessionRunHook): class LoggerHook(tf.train.SessionRunHook):
...@@ -71,6 +74,10 @@ class LoggerHookEstimator(tf.train.SessionRunHook): ...@@ -71,6 +74,10 @@ class LoggerHookEstimator(tf.train.SessionRunHook):
examples_per_sec, sec_per_batch)) examples_per_sec, sec_per_batch))
class EarlyStopException(Exception):
pass
class EarlyStopping(tf.train.SessionRunHook): class EarlyStopping(tf.train.SessionRunHook):
"""Stop training when a monitored quantity has stopped improving. """Stop training when a monitored quantity has stopped improving.
Based on Keras's EarlyStopping callback: Based on Keras's EarlyStopping callback:
...@@ -101,7 +108,7 @@ class EarlyStopping(tf.train.SessionRunHook): ...@@ -101,7 +108,7 @@ class EarlyStopping(tf.train.SessionRunHook):
""" """
def __init__(self, def __init__(self,
monitor='accuracy/total', monitor='accuracy/value',
min_delta=0, min_delta=0,
patience=0, patience=0,
mode='auto'): mode='auto'):
...@@ -113,8 +120,8 @@ class EarlyStopping(tf.train.SessionRunHook): ...@@ -113,8 +120,8 @@ class EarlyStopping(tf.train.SessionRunHook):
self.wait = 0 self.wait = 0
if mode not in ['auto', 'min', 'max']: if mode not in ['auto', 'min', 'max']:
logging.warning('EarlyStopping mode %s is unknown, ' logger.warn('EarlyStopping mode %s is unknown, '
'fallback to auto mode.' % mode) 'fallback to auto mode.' % mode)
mode = 'auto' mode = 'auto'
if mode == 'min': if mode == 'min':
...@@ -131,25 +138,37 @@ class EarlyStopping(tf.train.SessionRunHook): ...@@ -131,25 +138,37 @@ class EarlyStopping(tf.train.SessionRunHook):
self.min_delta *= 1 self.min_delta *= 1
else: else:
self.min_delta *= -1 self.min_delta *= -1
def begin(self):
# Allow instances to be re-used # Allow instances to be re-used
self.wait = 0 self.wait = 0
self.best = np.Inf if self.monitor_op == np.less else -np.Inf self.best = np.Inf if self.monitor_op == np.less else -np.Inf
self.monitor = _as_graph_element(self.monitor)
def begin(self):
self.values = []
if isinstance(self.monitor, six.string_types):
self.monitor = _as_graph_element(self.monitor)
else:
self.monitor = _as_graph_element(self.monitor.name)
def before_run(self, run_context): def before_run(self, run_context):
return tf.train.SessionRunArgs(self.monitor) return tf.train.SessionRunArgs(self.monitor)
def after_run(self, run_context, run_values): def after_run(self, run_context, run_values):
current = run_values.results self.values.append(run_values.results)
def _should_stop(self):
current = np.mean(self.values)
logger.info('%s is currently at %f and the best value was %f',
self.monitor.name, current, self.best)
if self.monitor_op(current - self.min_delta, self.best): if self.monitor_op(current - self.min_delta, self.best):
self.best = current self.best = current
self.wait = 0 self.wait = 0
else: else:
if self.wait >= self.patience: if self.wait >= self.patience:
run_context.request_stop() raise EarlyStopException(
print('Early stopping happened with {} at best of {} and ' 'Early stopping happened with {} at best of '
'current of {}'.format( '{} and current of {}'.format(
self.monitor, self.best, current)) self.monitor, self.best, current))
self.wait += 1 self.wait += 1
def end(self, session):
self._should_stop()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment