Commit 65f88ff0 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

fix earlystopping. Implement tests

parent ac0057a3
Pipeline #14617 failed with stages
in 19 minutes and 51 seconds
import tensorflow as tf
data = tf.feature_column.numeric_column('data', shape=[784])
estimator = tf.estimator.LinearClassifier(
feature_columns=[data], n_classes=10)
from bob.db.mnist import Database
import tensorflow as tf
database = Database()
def input_fn(mode):
if mode == tf.estimator.ModeKeys.TRAIN:
groups = 'train'
num_epochs = None
shuffle = True
else:
groups = 'test'
num_epochs = 1
shuffle = True
data, labels = database.data(groups=groups)
return tf.estimator.inputs.numpy_input_fn(
x={"data": data.astype('float32'), 'key': labels.astype('float32')},
y=labels.astype('int32'),
batch_size=128,
num_epochs=num_epochs,
shuffle=shuffle)
train_input_fn = input_fn(tf.estimator.ModeKeys.TRAIN)
eval_input_fn = input_fn(tf.estimator.ModeKeys.EVAL)
from bob.extension.config import load as read_config_files
from bob.io.base.test_utils import datafile
from bob.learn.tensorflow.estimators import Logits
from bob.learn.tensorflow.loss.BaseLoss import mean_cross_entropy_loss
from bob.learn.tensorflow.utils.hooks import EarlyStopping, EarlyStopException
import nose
import tensorflow as tf
@nose.tools.raises(EarlyStopException)
def test_early_stopping_linear_classifier():
config = read_config_files([
datafile('mnist_input_fn.py', __name__),
datafile('mnist_estimator.py', __name__),
])
estimator = config.estimator
train_input_fn = config.train_input_fn
eval_input_fn = config.eval_input_fn
hooks = [
EarlyStopping('linear/head/metrics/accuracy/value',
min_delta=0.001, patience=1),
]
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
eval_spec = tf.estimator.EvalSpec(
input_fn=eval_input_fn, hooks=hooks, throttle_secs=2, steps=10)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
@nose.tools.raises(EarlyStopException)
def test_early_stopping_logit_trainer():
config = read_config_files([
datafile('mnist_input_fn.py', __name__),
])
train_input_fn = config.train_input_fn
eval_input_fn = config.eval_input_fn
hooks = [
EarlyStopping('accuracy/value', min_delta=0.001, patience=1),
]
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
eval_spec = tf.estimator.EvalSpec(
input_fn=eval_input_fn, hooks=hooks, throttle_secs=2, steps=10)
def architecture(data, mode, **kwargs):
return data, dict()
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1e-1)
loss_op = mean_cross_entropy_loss
estimator = Logits(architecture, optimizer, loss_op,
n_classes=10, model_dir=None)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
import tensorflow as tf
import time
from datetime import datetime
from tensorflow.python.training.basic_session_run_hooks import _as_graph_element
import logging
import numpy as np
from tensorflow.python.training.basic_session_run_hooks import _as_graph_element
import six
import tensorflow as tf
import time
logger = logging.getLogger(__name__)
class LoggerHook(tf.train.SessionRunHook):
......@@ -71,6 +74,10 @@ class LoggerHookEstimator(tf.train.SessionRunHook):
examples_per_sec, sec_per_batch))
class EarlyStopException(Exception):
pass
class EarlyStopping(tf.train.SessionRunHook):
"""Stop training when a monitored quantity has stopped improving.
Based on Keras's EarlyStopping callback:
......@@ -101,7 +108,7 @@ class EarlyStopping(tf.train.SessionRunHook):
"""
def __init__(self,
monitor='accuracy/total',
monitor='accuracy/value',
min_delta=0,
patience=0,
mode='auto'):
......@@ -113,8 +120,8 @@ class EarlyStopping(tf.train.SessionRunHook):
self.wait = 0
if mode not in ['auto', 'min', 'max']:
logging.warning('EarlyStopping mode %s is unknown, '
'fallback to auto mode.' % mode)
logger.warn('EarlyStopping mode %s is unknown, '
'fallback to auto mode.' % mode)
mode = 'auto'
if mode == 'min':
......@@ -131,25 +138,37 @@ class EarlyStopping(tf.train.SessionRunHook):
self.min_delta *= 1
else:
self.min_delta *= -1
def begin(self):
# Allow instances to be re-used
self.wait = 0
self.best = np.Inf if self.monitor_op == np.less else -np.Inf
self.monitor = _as_graph_element(self.monitor)
def begin(self):
self.values = []
if isinstance(self.monitor, six.string_types):
self.monitor = _as_graph_element(self.monitor)
else:
self.monitor = _as_graph_element(self.monitor.name)
def before_run(self, run_context):
return tf.train.SessionRunArgs(self.monitor)
def after_run(self, run_context, run_values):
current = run_values.results
self.values.append(run_values.results)
def _should_stop(self):
current = np.mean(self.values)
logger.info('%s is currently at %f and the best value was %f',
self.monitor.name, current, self.best)
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
else:
if self.wait >= self.patience:
run_context.request_stop()
print('Early stopping happened with {} at best of {} and '
'current of {}'.format(
self.monitor, self.best, current))
raise EarlyStopException(
'Early stopping happened with {} at best of '
'{} and current of {}'.format(
self.monitor, self.best, current))
self.wait += 1
def end(self, session):
self._should_stop()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment