diff --git a/bob/learn/tensorflow/utils/hooks.py b/bob/learn/tensorflow/utils/hooks.py
index 78c6dbc0e0ec8702649eb18c26330c576775aca9..37adae26163d5b8ef0cef188c9408f14b22d7923 100644
--- a/bob/learn/tensorflow/utils/hooks.py
+++ b/bob/learn/tensorflow/utils/hooks.py
@@ -2,14 +2,13 @@ from datetime import datetime
from tensorflow.python.training.basic_session_run_hooks import _as_graph_element
import logging
import numpy as np
-import six
import tensorflow as tf
import time
logger = logging.getLogger(__name__)
-class TensorSummary(tf.train.SessionRunHook):
+class TensorSummary(tf.estimator.SessionRunHook):
"""Adds the given (scalar) tensors to tensorboard summaries"""
def __init__(self, tensors, tensor_names=None, **kwargs):
@@ -24,7 +23,7 @@ class TensorSummary(tf.train.SessionRunHook):
tf.summary.scalar(name, tensor)
-class LoggerHook(tf.train.SessionRunHook):
+class LoggerHook(tf.estimator.SessionRunHook):
"""Logs loss and runtime."""
def __init__(self, loss, batch_size, log_frequency):
@@ -56,7 +55,7 @@ class LoggerHook(tf.train.SessionRunHook):
examples_per_sec, sec_per_batch))
-class LoggerHookEstimator(tf.train.SessionRunHook):
+class LoggerHookEstimator(tf.estimator.SessionRunHook):
"""Logs loss and runtime."""
def __init__(self, estimator, batch_size, log_frequency):
@@ -93,7 +92,7 @@ class EarlyStopException(Exception):
pass
-class EarlyStopping(tf.train.SessionRunHook):
+class EarlyStopping(tf.estimator.SessionRunHook):
"""Stop training when a monitored quantity has stopped improving.
Based on Keras's EarlyStopping callback:
https://keras.io/callbacks/#earlystopping
@@ -160,7 +159,7 @@ class EarlyStopping(tf.train.SessionRunHook):
def begin(self):
self.values = []
- if isinstance(self.monitor, six.string_types):
+ if isinstance(self.monitor, str):
self.monitor = _as_graph_element(self.monitor)
else:
self.monitor = _as_graph_element(self.monitor.name)