Skip to content
Snippets Groups Projects
Commit bbd87082 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Hacked a hook. Still to be fixed

parent 2ab444c4
No related branches found
No related tags found
1 merge request!21Resolve "Adopt to the Estimators API"
......@@ -33,3 +33,34 @@ class LoggerHook(tf.train.SessionRunHook):
'sec/batch)')
print(format_str % (datetime.now(), self._step, loss_value,
examples_per_sec, sec_per_batch))
class LoggerHookEstimator(tf.train.SessionRunHook):
"""Logs loss and runtime."""
def __init__(self, estimator, batch_size, log_frequency):
self.estimator = estimator
self.batch_size = batch_size
self.log_frequency = log_frequency
def begin(self):
self._step = -1
self._start_time = time.time()
def before_run(self, run_context):
self._step += 1
return tf.train.SessionRunArgs(self.estimator.loss) # Asks for loss value.
def after_run(self, run_context, run_values):
if self._step % self.log_frequency == 0:
current_time = time.time()
duration = current_time - self._start_time
self._start_time = current_time
loss_value = run_values.results
examples_per_sec = self.log_frequency * self.batch_size / duration
sec_per_batch = float(duration / self.log_frequency)
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
print(format_str % (datetime.now(), self._step, loss_value,
examples_per_sec, sec_per_batch))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment