From 8082804e3cedf5ccf610e346593dccf7fbb9c58c Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Fri, 7 Feb 2020 15:25:38 +0100 Subject: [PATCH] fix imports --- bob/learn/tensorflow/utils/hooks.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/bob/learn/tensorflow/utils/hooks.py b/bob/learn/tensorflow/utils/hooks.py index 78c6dbc0..37adae26 100644 --- a/bob/learn/tensorflow/utils/hooks.py +++ b/bob/learn/tensorflow/utils/hooks.py @@ -2,14 +2,13 @@ from datetime import datetime from tensorflow.python.training.basic_session_run_hooks import _as_graph_element import logging import numpy as np -import six import tensorflow as tf import time logger = logging.getLogger(__name__) -class TensorSummary(tf.train.SessionRunHook): +class TensorSummary(tf.estimator.SessionRunHook): """Adds the given (scalar) tensors to tensorboard summaries""" def __init__(self, tensors, tensor_names=None, **kwargs): @@ -24,7 +23,7 @@ class TensorSummary(tf.train.SessionRunHook): tf.summary.scalar(name, tensor) -class LoggerHook(tf.train.SessionRunHook): +class LoggerHook(tf.estimator.SessionRunHook): """Logs loss and runtime.""" def __init__(self, loss, batch_size, log_frequency): @@ -56,7 +55,7 @@ class LoggerHook(tf.train.SessionRunHook): examples_per_sec, sec_per_batch)) -class LoggerHookEstimator(tf.train.SessionRunHook): +class LoggerHookEstimator(tf.estimator.SessionRunHook): """Logs loss and runtime.""" def __init__(self, estimator, batch_size, log_frequency): @@ -93,7 +92,7 @@ class EarlyStopException(Exception): pass -class EarlyStopping(tf.train.SessionRunHook): +class EarlyStopping(tf.estimator.SessionRunHook): """Stop training when a monitored quantity has stopped improving. Based on Keras's EarlyStopping callback: https://keras.io/callbacks/#earlystopping @@ -160,7 +159,7 @@ class EarlyStopping(tf.train.SessionRunHook): def begin(self): self.values = [] - if isinstance(self.monitor, six.string_types): + if isinstance(self.monitor, str): self.monitor = _as_graph_element(self.monitor) else: self.monitor = _as_graph_element(self.monitor.name) -- GitLab