Commit 8082804e authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

fix imports

parent 27e73aec
...@@ -2,14 +2,13 @@ from datetime import datetime ...@@ -2,14 +2,13 @@ from datetime import datetime
from tensorflow.python.training.basic_session_run_hooks import _as_graph_element from tensorflow.python.training.basic_session_run_hooks import _as_graph_element
import logging import logging
import numpy as np import numpy as np
import six
import tensorflow as tf import tensorflow as tf
import time import time
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TensorSummary(tf.train.SessionRunHook): class TensorSummary(tf.estimator.SessionRunHook):
"""Adds the given (scalar) tensors to tensorboard summaries""" """Adds the given (scalar) tensors to tensorboard summaries"""
def __init__(self, tensors, tensor_names=None, **kwargs): def __init__(self, tensors, tensor_names=None, **kwargs):
...@@ -24,7 +23,7 @@ class TensorSummary(tf.train.SessionRunHook): ...@@ -24,7 +23,7 @@ class TensorSummary(tf.train.SessionRunHook):
tf.summary.scalar(name, tensor) tf.summary.scalar(name, tensor)
class LoggerHook(tf.train.SessionRunHook): class LoggerHook(tf.estimator.SessionRunHook):
"""Logs loss and runtime.""" """Logs loss and runtime."""
def __init__(self, loss, batch_size, log_frequency): def __init__(self, loss, batch_size, log_frequency):
...@@ -56,7 +55,7 @@ class LoggerHook(tf.train.SessionRunHook): ...@@ -56,7 +55,7 @@ class LoggerHook(tf.train.SessionRunHook):
examples_per_sec, sec_per_batch)) examples_per_sec, sec_per_batch))
class LoggerHookEstimator(tf.train.SessionRunHook): class LoggerHookEstimator(tf.estimator.SessionRunHook):
"""Logs loss and runtime.""" """Logs loss and runtime."""
def __init__(self, estimator, batch_size, log_frequency): def __init__(self, estimator, batch_size, log_frequency):
...@@ -93,7 +92,7 @@ class EarlyStopException(Exception): ...@@ -93,7 +92,7 @@ class EarlyStopException(Exception):
pass pass
class EarlyStopping(tf.train.SessionRunHook): class EarlyStopping(tf.estimator.SessionRunHook):
"""Stop training when a monitored quantity has stopped improving. """Stop training when a monitored quantity has stopped improving.
Based on Keras's EarlyStopping callback: Based on Keras's EarlyStopping callback:
https://keras.io/callbacks/#earlystopping https://keras.io/callbacks/#earlystopping
...@@ -160,7 +159,7 @@ class EarlyStopping(tf.train.SessionRunHook): ...@@ -160,7 +159,7 @@ class EarlyStopping(tf.train.SessionRunHook):
def begin(self): def begin(self):
self.values = [] self.values = []
if isinstance(self.monitor, six.string_types): if isinstance(self.monitor, str):
self.monitor = _as_graph_element(self.monitor) self.monitor = _as_graph_element(self.monitor)
else: else:
self.monitor = _as_graph_element(self.monitor.name) self.monitor = _as_graph_element(self.monitor.name)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment