diff --git a/bob/learn/tensorflow/utils/hooks.py b/bob/learn/tensorflow/utils/hooks.py index 1875e519a77a11a54f45fe6ab5726322173b17bf..5a702f8d997538d064e510e0068771c55e4ea9c9 100644 --- a/bob/learn/tensorflow/utils/hooks.py +++ b/bob/learn/tensorflow/utils/hooks.py @@ -33,3 +33,34 @@ class LoggerHook(tf.train.SessionRunHook): 'sec/batch)') print(format_str % (datetime.now(), self._step, loss_value, examples_per_sec, sec_per_batch)) + +class LoggerHookEstimator(tf.train.SessionRunHook): + """Logs loss and runtime.""" + + def __init__(self, estimator, batch_size, log_frequency): + self.estimator = estimator + self.batch_size = batch_size + self.log_frequency = log_frequency + + def begin(self): + self._step = -1 + self._start_time = time.time() + + def before_run(self, run_context): + self._step += 1 + return tf.train.SessionRunArgs(self.estimator.loss) # Asks for loss value. + + def after_run(self, run_context, run_values): + if self._step % self.log_frequency == 0: + current_time = time.time() + duration = current_time - self._start_time + self._start_time = current_time + + loss_value = run_values.results + examples_per_sec = self.log_frequency * self.batch_size / duration + sec_per_batch = float(duration / self.log_frequency) + + format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' + 'sec/batch)') + print(format_str % (datetime.now(), self._step, loss_value, + examples_per_sec, sec_per_batch))