From 8082804e3cedf5ccf610e346593dccf7fbb9c58c Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Fri, 7 Feb 2020 15:25:38 +0100
Subject: [PATCH] fix imports

---
 bob/learn/tensorflow/utils/hooks.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/bob/learn/tensorflow/utils/hooks.py b/bob/learn/tensorflow/utils/hooks.py
index 78c6dbc0..37adae26 100644
--- a/bob/learn/tensorflow/utils/hooks.py
+++ b/bob/learn/tensorflow/utils/hooks.py
@@ -2,14 +2,13 @@ from datetime import datetime
 from tensorflow.python.training.basic_session_run_hooks import _as_graph_element
 import logging
 import numpy as np
-import six
 import tensorflow as tf
 import time
 
 logger = logging.getLogger(__name__)
 
 
-class TensorSummary(tf.train.SessionRunHook):
+class TensorSummary(tf.estimator.SessionRunHook):
     """Adds the given (scalar) tensors to tensorboard summaries"""
 
     def __init__(self, tensors, tensor_names=None, **kwargs):
@@ -24,7 +23,7 @@ class TensorSummary(tf.train.SessionRunHook):
             tf.summary.scalar(name, tensor)
 
 
-class LoggerHook(tf.train.SessionRunHook):
+class LoggerHook(tf.estimator.SessionRunHook):
     """Logs loss and runtime."""
 
     def __init__(self, loss, batch_size, log_frequency):
@@ -56,7 +55,7 @@ class LoggerHook(tf.train.SessionRunHook):
                                 examples_per_sec, sec_per_batch))
 
 
-class LoggerHookEstimator(tf.train.SessionRunHook):
+class LoggerHookEstimator(tf.estimator.SessionRunHook):
     """Logs loss and runtime."""
 
     def __init__(self, estimator, batch_size, log_frequency):
@@ -93,7 +92,7 @@ class EarlyStopException(Exception):
     pass
 
 
-class EarlyStopping(tf.train.SessionRunHook):
+class EarlyStopping(tf.estimator.SessionRunHook):
     """Stop training when a monitored quantity has stopped improving.
     Based on Keras's EarlyStopping callback:
     https://keras.io/callbacks/#earlystopping
@@ -160,7 +159,7 @@ class EarlyStopping(tf.train.SessionRunHook):
 
     def begin(self):
         self.values = []
-        if isinstance(self.monitor, six.string_types):
+        if isinstance(self.monitor, str):
             self.monitor = _as_graph_element(self.monitor)
         else:
             self.monitor = _as_graph_element(self.monitor.name)
-- 
GitLab