From f720d2ed15afdbe6c70ad9c6e01d04468589ecad Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Fri, 7 Feb 2020 15:25:38 +0100
Subject: [PATCH] fix imports
---
bob/learn/tensorflow/utils/hooks.py | 11 +++++------
1 file changed, 5 insertions(+), 6 deletions(-)
diff --git a/bob/learn/tensorflow/utils/hooks.py b/bob/learn/tensorflow/utils/hooks.py
index 78c6dbc0..37adae26 100644
--- a/bob/learn/tensorflow/utils/hooks.py
+++ b/bob/learn/tensorflow/utils/hooks.py
@@ -2,14 +2,13 @@ from datetime import datetime
from tensorflow.python.training.basic_session_run_hooks import _as_graph_element
import logging
import numpy as np
-import six
import tensorflow as tf
import time
logger = logging.getLogger(__name__)
-class TensorSummary(tf.train.SessionRunHook):
+class TensorSummary(tf.estimator.SessionRunHook):
"""Adds the given (scalar) tensors to tensorboard summaries"""
def __init__(self, tensors, tensor_names=None, **kwargs):
@@ -24,7 +23,7 @@ class TensorSummary(tf.train.SessionRunHook):
tf.summary.scalar(name, tensor)
-class LoggerHook(tf.train.SessionRunHook):
+class LoggerHook(tf.estimator.SessionRunHook):
"""Logs loss and runtime."""
def __init__(self, loss, batch_size, log_frequency):
@@ -56,7 +55,7 @@ class LoggerHook(tf.train.SessionRunHook):
examples_per_sec, sec_per_batch))
-class LoggerHookEstimator(tf.train.SessionRunHook):
+class LoggerHookEstimator(tf.estimator.SessionRunHook):
"""Logs loss and runtime."""
def __init__(self, estimator, batch_size, log_frequency):
@@ -93,7 +92,7 @@ class EarlyStopException(Exception):
pass
-class EarlyStopping(tf.train.SessionRunHook):
+class EarlyStopping(tf.estimator.SessionRunHook):
"""Stop training when a monitored quantity has stopped improving.
Based on Keras's EarlyStopping callback:
https://keras.io/callbacks/#earlystopping
@@ -160,7 +159,7 @@ class EarlyStopping(tf.train.SessionRunHook):
def begin(self):
self.values = []
- if isinstance(self.monitor, six.string_types):
+ if isinstance(self.monitor, str):
self.monitor = _as_graph_element(self.monitor)
else:
self.monitor = _as_graph_element(self.monitor.name)
--
GitLab