Commit 8082804e authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

fix imports

parent 27e73aec
......@@ -2,14 +2,13 @@ from datetime import datetime
from tensorflow.python.training.basic_session_run_hooks import _as_graph_element
import logging
import numpy as np
import six
import tensorflow as tf
import time
logger = logging.getLogger(__name__)
class TensorSummary(tf.train.SessionRunHook):
class TensorSummary(tf.estimator.SessionRunHook):
"""Adds the given (scalar) tensors to tensorboard summaries"""
def __init__(self, tensors, tensor_names=None, **kwargs):
......@@ -24,7 +23,7 @@ class TensorSummary(tf.train.SessionRunHook):
tf.summary.scalar(name, tensor)
class LoggerHook(tf.train.SessionRunHook):
class LoggerHook(tf.estimator.SessionRunHook):
"""Logs loss and runtime."""
def __init__(self, loss, batch_size, log_frequency):
......@@ -56,7 +55,7 @@ class LoggerHook(tf.train.SessionRunHook):
examples_per_sec, sec_per_batch))
class LoggerHookEstimator(tf.train.SessionRunHook):
class LoggerHookEstimator(tf.estimator.SessionRunHook):
"""Logs loss and runtime."""
def __init__(self, estimator, batch_size, log_frequency):
......@@ -93,7 +92,7 @@ class EarlyStopException(Exception):
pass
class EarlyStopping(tf.train.SessionRunHook):
class EarlyStopping(tf.estimator.SessionRunHook):
"""Stop training when a monitored quantity has stopped improving.
Based on Keras's EarlyStopping callback:
https://keras.io/callbacks/#earlystopping
......@@ -160,7 +159,7 @@ class EarlyStopping(tf.train.SessionRunHook):
def begin(self):
self.values = []
if isinstance(self.monitor, six.string_types):
if isinstance(self.monitor, str):
self.monitor = _as_graph_element(self.monitor)
else:
self.monitor = _as_graph_element(self.monitor.name)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment