Skip to content
Snippets Groups Projects

Resolve "Adopt to the Estimators API"

Merged Tiago de Freitas Pereira requested to merge 40-adopt-to-the-estimators-api into master
1 file
+ 31
0
Compare changes
  • Side-by-side
  • Inline
@@ -33,3 +33,34 @@ class LoggerHook(tf.train.SessionRunHook):
@@ -33,3 +33,34 @@ class LoggerHook(tf.train.SessionRunHook):
'sec/batch)')
'sec/batch)')
print(format_str % (datetime.now(), self._step, loss_value,
print(format_str % (datetime.now(), self._step, loss_value,
examples_per_sec, sec_per_batch))
examples_per_sec, sec_per_batch))
 
 
class LoggerHookEstimator(tf.train.SessionRunHook):
 
"""Logs loss and runtime."""
 
 
def __init__(self, estimator, batch_size, log_frequency):
 
self.estimator = estimator
 
self.batch_size = batch_size
 
self.log_frequency = log_frequency
 
 
def begin(self):
 
self._step = -1
 
self._start_time = time.time()
 
 
def before_run(self, run_context):
 
self._step += 1
 
return tf.train.SessionRunArgs(self.estimator.loss) # Asks for loss value.
 
 
def after_run(self, run_context, run_values):
 
if self._step % self.log_frequency == 0:
 
current_time = time.time()
 
duration = current_time - self._start_time
 
self._start_time = current_time
 
 
loss_value = run_values.results
 
examples_per_sec = self.log_frequency * self.batch_size / duration
 
sec_per_batch = float(duration / self.log_frequency)
 
 
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
 
'sec/batch)')
 
print(format_str % (datetime.now(), self._step, loss_value,
 
examples_per_sec, sec_per_batch))
Loading