Commit f8449547 authored by Amir MOHAMMADI

remove hooks

parent bdcb031a
1 merge request: !85 Porting to TF2
from bob.extension.config import load as read_config_files
from bob.io.base.test_utils import datafile
from bob.learn.tensorflow.estimators import Logits
from bob.learn.tensorflow.loss.BaseLoss import mean_cross_entropy_loss
from bob.learn.tensorflow.utils.hooks import EarlyStopping, EarlyStopException
import nose
import tensorflow as tf
import shutil
from nose.plugins.attrib import attr
# @nose.tools.raises(EarlyStopException)
# @attr('slow')
# def test_early_stopping_linear_classifier():
#     config = read_config_files([
#         datafile('mnist_input_fn.py', __name__),
#         datafile('mnist_estimator.py', __name__),
#     ])
#     estimator = config.estimator
#     train_input_fn = config.train_input_fn
#     eval_input_fn = config.eval_input_fn
#     hooks = [
#         EarlyStopping(
#             'linear/metrics/accuracy/total', min_delta=0.001, patience=1),
#     ]
#     train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
#     eval_spec = tf.estimator.EvalSpec(
#         input_fn=eval_input_fn, hooks=hooks, throttle_secs=2, steps=10)
#     try:
#         tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
#     finally:
#         shutil.rmtree(estimator.model_dir)
# @nose.tools.raises(EarlyStopException)
# @attr('slow')
# def test_early_stopping_logit_trainer():
#     config = read_config_files([
#         datafile('mnist_input_fn.py', __name__),
#     ])
#     train_input_fn = config.train_input_fn
#     eval_input_fn = config.eval_input_fn
#     hooks = [
#         EarlyStopping('accuracy/value', min_delta=0.001, patience=1),
#     ]
#     train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn)
#     eval_spec = tf.estimator.EvalSpec(
#         input_fn=eval_input_fn, hooks=hooks, throttle_secs=2, steps=10)
#     def architecture(data, mode, **kwargs):
#         return data, dict()
#     optimizer = tf.train.GradientDescentOptimizer(learning_rate=1e-1)
#     loss_op = mean_cross_entropy_loss
#     estimator = Logits(
#         architecture, optimizer, loss_op, n_classes=10, model_dir=None)
#     try:
#         tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
#     finally:
#         shutil.rmtree(estimator.model_dir)
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
import tensorflow as tf
from bob.learn.tensorflow.network import dummy
from bob.learn.tensorflow.estimators import Logits, LogitsCenterLoss
from bob.learn.tensorflow.dataset.image import shuffle_data_and_labels_image_augmentation
import pkg_resources
from bob.learn.tensorflow.utils.hooks import LoggerHookEstimator
from bob.learn.tensorflow.loss import mean_cross_entropy_loss
from nose.plugins.attrib import attr
import shutil
import os
model_dir = "./temp"
learning_rate = 0.1
data_shape = (250, 250, 3) # size of atnt images
data_type = tf.float32
batch_size = 16
validation_batch_size = 250
epochs = 1
steps = 5000
@attr('slow')
def test_logitstrainer_images():
    # Trainer logits
    try:
        embedding_validation = False
        trainer = Logits(
            model_dir=model_dir,
            architecture=dummy,
            optimizer=tf.train.GradientDescentOptimizer(learning_rate),
            n_classes=10,
            loss_op=mean_cross_entropy_loss,
            embedding_validation=embedding_validation,
            validation_batch_size=validation_batch_size,
            apply_moving_averages=False)
        run_logitstrainer_images(trainer)
    finally:
        try:
            shutil.rmtree(model_dir, ignore_errors=True)
        except Exception:
            pass
def run_logitstrainer_images(trainer):
    # Cleaning up
    tf.reset_default_graph()
    assert len(tf.global_variables()) == 0

    filenames = [
        pkg_resources.resource_filename(
            __name__, 'data/dummy_image_database/m301_01_p01_i0_0.png'),
        pkg_resources.resource_filename(
            __name__, 'data/dummy_image_database/m301_01_p02_i0_0.png'),
        pkg_resources.resource_filename(
            __name__, 'data/dummy_image_database/m304_01_p01_i0_0.png'),
        pkg_resources.resource_filename(
            __name__, 'data/dummy_image_database/m304_02_f12_i0_0.png')
    ]
    labels = [0, 0, 1, 1]

    def input_fn():
        return shuffle_data_and_labels_image_augmentation(
            filenames,
            labels,
            data_shape,
            data_type,
            batch_size,
            epochs=epochs)

    def input_fn_validation():
        return shuffle_data_and_labels_image_augmentation(
            filenames,
            labels,
            data_shape,
            data_type,
            validation_batch_size,
            epochs=1000)

    hooks = [
        LoggerHookEstimator(trainer, 16, 300),
        tf.train.SummarySaverHook(
            save_steps=1000,
            output_dir=model_dir,
            scaffold=tf.train.Scaffold(),
            summary_writer=tf.summary.FileWriter(model_dir))
    ]

    trainer.train(input_fn, steps=steps, hooks=hooks)
    acc = trainer.evaluate(input_fn_validation)
    assert acc['accuracy'] > 0.30, acc['accuracy']

    # Cleaning up
    tf.reset_default_graph()
    assert len(tf.global_variables()) == 0
@@ -59,6 +59,5 @@ def test_embedding_accuracy_tensors():
     data = tf.convert_to_tensor(data.astype("float32"))
     labels = tf.convert_to_tensor(labels.astype("int64"))
-    sess = tf.Session()
-    accuracy = sess.run(compute_embedding_accuracy_tensors(data, labels))
+    accuracy = compute_embedding_accuracy_tensors(data, labels)
     assert accuracy == 1.
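This hunk is the TF2 port in miniature: under eager execution the tensor returned by compute_embedding_accuracy_tensors already holds a concrete value, so no Session is needed. A minimal sketch of the difference, using an assumed scalar tensor:

    import tensorflow as tf  # TF2: eager execution is on by default

    t = tf.constant(1.0)
    # TF1 graph mode required: value = tf.Session().run(t)
    # TF2 eager mode: the tensor carries its value directly.
    assert float(t) == 1.0  # equivalently, t.numpy()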
@@ -6,6 +6,5 @@ from .eval import *
from .keras import *
from .train import *
from .graph import *
from .network import *
from .math import *
from .reproducible import *
from datetime import datetime
from tensorflow.python.training.basic_session_run_hooks import _as_graph_element
import logging
import numpy as np
import tensorflow as tf
import time
logger = logging.getLogger(__name__)
class TensorSummary(tf.estimator.SessionRunHook):
    """Adds the given (scalar) tensors to tensorboard summaries."""

    def __init__(self, tensors, tensor_names=None, **kwargs):
        super().__init__(**kwargs)
        self.tensors = list(tensors)
        if tensor_names is None:
            tensor_names = [t.name for t in self.tensors]
        self.tensor_names = list(tensor_names)

    def begin(self):
        for name, tensor in zip(self.tensor_names, self.tensors):
            tf.summary.scalar(name, tensor)
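# A hypothetical usage sketch (not part of the original file; the tensor name
# 'accuracy/value:0' is an assumed placeholder): fetch a scalar from the
# current graph and register the hook so it is exported to TensorBoard.
#
#     acc = tf.get_default_graph().get_tensor_by_name('accuracy/value:0')
#     estimator.train(input_fn, hooks=[TensorSummary([acc], ['accuracy'])])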
class LoggerHook(tf.estimator.SessionRunHook):
    """Logs loss and runtime."""

    def __init__(self, loss, batch_size, log_frequency):
        self.loss = loss
        self.batch_size = batch_size
        self.log_frequency = log_frequency

    def begin(self):
        self._step = -1
        self._start_time = time.time()

    def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs(self.loss)  # Asks for loss value.

    def after_run(self, run_context, run_values):
        if self._step % self.log_frequency == 0:
            current_time = time.time()
            duration = current_time - self._start_time
            self._start_time = current_time

            loss_value = run_values.results
            examples_per_sec = self.log_frequency * self.batch_size / duration
            sec_per_batch = float(duration / self.log_frequency)

            format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (datetime.now(), self._step, loss_value,
                                examples_per_sec, sec_per_batch))
class LoggerHookEstimator(tf.estimator.SessionRunHook):
    """Logs loss and runtime."""

    def __init__(self, estimator, batch_size, log_frequency):
        self.estimator = estimator
        self.batch_size = batch_size
        self.log_frequency = log_frequency

    def begin(self):
        self._step = -1
        self._start_time = time.time()

    def before_run(self, run_context):
        self._step += 1
        # Asks for the loss value.
        return tf.train.SessionRunArgs(self.estimator.loss)

    def after_run(self, run_context, run_values):
        if self._step % self.log_frequency == 0:
            current_time = time.time()
            duration = current_time - self._start_time
            self._start_time = current_time

            loss_value = run_values.results
            examples_per_sec = self.log_frequency * self.batch_size / duration
            sec_per_batch = float(duration / self.log_frequency)

            format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (datetime.now(), self._step, loss_value,
                                examples_per_sec, sec_per_batch))
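# A minimal usage sketch (not part of the original file; `loss` and `train_op`
# are assumed to be a tensor and an op already built in the default graph):
# SessionRunHooks plug into TF1-style monitored sessions, which invoke
# begin/before_run/after_run around each sess.run call.
#
#     hook = LoggerHook(loss, batch_size=32, log_frequency=100)
#     with tf.train.MonitoredTrainingSession(hooks=[hook]) as sess:
#         while not sess.should_stop():
#             sess.run(train_op)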
class EarlyStopException(Exception):
    pass
class EarlyStopping(tf.estimator.SessionRunHook):
    """Stop training when a monitored quantity has stopped improving.

    Based on Keras's EarlyStopping callback:
    https://keras.io/callbacks/#earlystopping

    The original implementation worked on epochs. Currently there is no way
    to know the epoch count in estimator training, so the criterion is
    applied in steps instead of epochs.

    Parameters
    ----------
    monitor
        Quantity to be monitored.
    min_delta
        Minimum change in the monitored quantity to qualify as an
        improvement; an absolute change of less than min_delta counts as no
        improvement.
    patience
        Number of steps with no improvement after which training will be
        stopped. Please use large patience values, since this hook counts
        steps instead of epochs, unlike the equivalent callback in Keras.
    mode
        One of {auto, min, max}. In `min` mode, training stops when the
        monitored quantity has stopped decreasing; in `max` mode it stops
        when the monitored quantity has stopped increasing; in `auto` mode,
        the direction is inferred automatically from the name of the
        monitored quantity.
    """

    def __init__(self,
                 monitor='accuracy/value',
                 min_delta=0,
                 patience=0,
                 mode='auto'):
        super(EarlyStopping, self).__init__()

        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        if mode not in ['auto', 'min', 'max']:
            logger.warning('EarlyStopping mode %s is unknown, '
                           'falling back to auto mode.', mode)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1

        # Allow instances to be re-used
        self.wait = 0
        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
        self.global_step_of_best = 0

    def begin(self):
        self.values = []
        if isinstance(self.monitor, str):
            self.monitor = _as_graph_element(self.monitor)
        else:
            self.monitor = _as_graph_element(self.monitor.name)
        self.global_step_tensor = tf.train.get_global_step()

    def before_run(self, run_context):
        return tf.train.SessionRunArgs([self.monitor, self.global_step_tensor])

    def after_run(self, run_context, run_values):
        monitor, global_step = run_values.results
        self.values.append(monitor)
        # The global step does not change during evaluation, so keeping one
        # of them is enough.
        self.global_step_value = global_step

    def _should_stop(self):
        current = np.mean(self.values)

        logger.info(
            '%s is currently at %f (at step %d) and the best value was %f '
            '(at step %d)', self.monitor.name, current,
            self.global_step_value, self.best, self.global_step_of_best)

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
            self.global_step_of_best = self.global_step_value
        else:
            if self.wait >= self.patience:
                message = 'Early stopping happened with {} at best of ' \
                    '{} (at step {}) and current of {} (at step {})'.format(
                        self.monitor.name, self.best, self.global_step_of_best,
                        current, self.global_step_value)
                logger.info(message)
                raise EarlyStopException(message)
            self.wait += 1

    def end(self, session):
        self._should_stop()
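# With the hooks removed in this commit, the natural TF2 replacement is the
# Keras callback this class was modeled on. A minimal sketch (not part of the
# original file), assuming a compiled tf.keras `model` and arrays `x`, `y`,
# `x_val`, `y_val`:
#
#     early_stop = tf.keras.callbacks.EarlyStopping(
#         monitor='val_accuracy', min_delta=0.001, patience=10, mode='auto')
#     model.fit(x, y, validation_data=(x_val, y_val), callbacks=[early_stop])
#
# Note that the Keras callback counts epochs, not steps, so much smaller
# patience values are appropriate.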
import tensorflow.keras.backend as K
from .network import is_trainable
import tensorflow as tf
import logging
logger = logging.getLogger(__name__)
def is_trainable(name, trainable_variables, mode=tf.estimator.ModeKeys.TRAIN):
    """
    Checks whether a variable is trainable or not.

    Parameters
    ----------
    name: str
        Layer name
    trainable_variables: list
        List containing the variables or scopes to be trained.
        If None, the variable/scope is trained.
    mode
        The estimator mode; nothing is trainable outside of
        `tf.estimator.ModeKeys.TRAIN`.
    """
    # If we are not training, nothing is trainable
    if mode != tf.estimator.ModeKeys.TRAIN:
        return False

    # If None, we train by default
    if trainable_variables is None:
        return True

    # The whole scope is trained only if it is listed explicitly
    return name in trainable_variables
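# A quick illustration of the semantics above (scope names are hypothetical):
#
#     is_trainable('conv1', None)              # True: train all by default
#     is_trainable('conv1', ['fc1'])           # False: scope not listed
#     is_trainable('conv1', ['conv1', 'fc1'])  # True: scope listed
#     is_trainable('conv1', None,
#                  mode=tf.estimator.ModeKeys.PREDICT)  # False outside TRAIN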
def keras_channels_index():
    # The channels axis is -3 for "channels_first" (C, H, W) inputs and -1
    # for "channels_last" (H, W, C) inputs.
    return -3 if K.image_data_format() == "channels_first" else -1