From 5cf18a79c5d1060b5d7e18a99870790639dae0e3 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Thu, 14 Dec 2017 14:54:17 +0100 Subject: [PATCH] Exit normally with the earlystop hook --- MANIFEST.in | 2 +- .../tensorflow/script/train_and_evaluate.py | 18 ++++++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 9e6e9b02..05e855ef 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,3 @@ -include README.rst bootstrap-buildout.py buildout.cfg COPYING version.txt requirements.txt +include README.rst buildout.cfg LICENSE version.txt requirements.txt recursive-include doc *.py *.rst recursive-include bob *.wav *.hdf5 *.pickle *.meta *.ckp *.py *.png diff --git a/bob/learn/tensorflow/script/train_and_evaluate.py b/bob/learn/tensorflow/script/train_and_evaluate.py index 015577b8..86ee6827 100644 --- a/bob/learn/tensorflow/script/train_and_evaluate.py +++ b/bob/learn/tensorflow/script/train_and_evaluate.py @@ -30,6 +30,11 @@ The configuration files should have the following objects totally: estimator train_spec eval_spec + + ## Optional objects: + exit_ok_exceptions : [Exception] + A list of exceptions to exit properly if they occur. If nothing is + provided, the EarlyStopException is handled by default. """ from __future__ import absolute_import from __future__ import division @@ -40,6 +45,7 @@ import tensorflow as tf from bob.extension.config import load as read_config_file from bob.learn.tensorflow.utils.commandline import \ get_from_config_or_commandline +from bob.learn.tensorflow.utils.hooks import EarlyStopException from bob.core.log import setup, set_verbosity_level logger = setup(__name__) @@ -62,13 +68,21 @@ def main(argv=None): # Sets-up logging set_verbosity_level(logger, verbosity) - # required arguments + # required objects estimator = config.estimator train_spec = config.train_spec eval_spec = config.eval_spec + # optional objects + exit_ok_exceptions = getattr(config, 'exit_ok_exceptions', + (EarlyStopException,)) + # Train and evaluate - tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) + try: + tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) + except exit_ok_exceptions as e: + logger.exception(e) + return if __name__ == '__main__': -- GitLab