From 3806dec45b68a46c7c2ce96758a271a2862b59aa Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Thu, 19 Oct 2017 16:26:41 +0200 Subject: [PATCH] Moved the estimators stuff from the trainers to the estimators module Renamed Logits trainer --- .../LogitsTrainer.py => estimators/Logits.py} | 16 ++++++------- bob/learn/tensorflow/estimators/__init__.py | 23 +++++++++++++++++++ .../tensorflow/test/test_image_dataset.py | 2 +- .../test/test_onegraph_estimator.py | 10 ++++---- bob/learn/tensorflow/trainers/__init__.py | 1 - 5 files changed, 37 insertions(+), 15 deletions(-) rename bob/learn/tensorflow/{trainers/LogitsTrainer.py => estimators/Logits.py} (95%) create mode 100755 bob/learn/tensorflow/estimators/__init__.py diff --git a/bob/learn/tensorflow/trainers/LogitsTrainer.py b/bob/learn/tensorflow/estimators/Logits.py similarity index 95% rename from bob/learn/tensorflow/trainers/LogitsTrainer.py rename to bob/learn/tensorflow/estimators/Logits.py index a7d12b01..5c582f82 100755 --- a/bob/learn/tensorflow/trainers/LogitsTrainer.py +++ b/bob/learn/tensorflow/estimators/Logits.py @@ -22,7 +22,7 @@ import logging logger = logging.getLogger("bob.learn") -class LogitsTrainer(estimator.Estimator): +class Logits(estimator.Estimator): """ NN Trainer whose with logits as last layer @@ -149,12 +149,12 @@ class LogitsTrainer(estimator.Estimator): return tf.estimator.EstimatorSpec( mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops) - super(LogitsTrainer, self).__init__(model_fn=_model_fn, - model_dir=model_dir, - config=config) + super(Logits, self).__init__(model_fn=_model_fn, + model_dir=model_dir, + config=config) -class LogitsCenterLossTrainer(estimator.Estimator): +class LogitsCenterLoss(estimator.Estimator): """ NN Trainer whose with logits as last layer @@ -281,6 +281,6 @@ class LogitsCenterLossTrainer(estimator.Estimator): return tf.estimator.EstimatorSpec( mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops) - super(LogitsCenterLossTrainer, self).__init__(model_fn=_model_fn, - model_dir=model_dir, - config=config) + super(LogitsCenterLoss, self).__init__(model_fn=_model_fn, + model_dir=model_dir, + config=config) diff --git a/bob/learn/tensorflow/estimators/__init__.py b/bob/learn/tensorflow/estimators/__init__.py new file mode 100755 index 00000000..04ce0a5e --- /dev/null +++ b/bob/learn/tensorflow/estimators/__init__.py @@ -0,0 +1,23 @@ +from .Logits import Logits, LogitsCenterLoss + +# gets sphinx autodoc done right - don't remove it +def __appropriate__(*args): + """Says object was actually declared here, an not on the import module. + + Parameters: + + *args: An iterable of objects to modify + + Resolves `Sphinx referencing issues + <https://github.com/sphinx-doc/sphinx/issues/3048>` + """ + + for obj in args: obj.__module__ = __name__ + +__appropriate__( + Logits, + LogitsCenterLoss + ) +__all__ = [_ for _ in dir() if not _.startswith('_')] + + diff --git a/bob/learn/tensorflow/test/test_image_dataset.py b/bob/learn/tensorflow/test/test_image_dataset.py index 5e0e1b19..9fdb9827 100755 --- a/bob/learn/tensorflow/test/test_image_dataset.py +++ b/bob/learn/tensorflow/test/test_image_dataset.py @@ -52,6 +52,7 @@ def test_logitstrainer_images(): shutil.rmtree(model_dir, ignore_errors=True) except Exception: pass + def run_logitstrainer_images(trainer): # Cleaning up @@ -93,4 +94,3 @@ def run_logitstrainer_images(trainer): tf.reset_default_graph() assert len(tf.global_variables()) == 0 - diff --git a/bob/learn/tensorflow/test/test_onegraph_estimator.py b/bob/learn/tensorflow/test/test_onegraph_estimator.py index 33cadd80..9a252d8d 100755 --- a/bob/learn/tensorflow/test/test_onegraph_estimator.py +++ b/bob/learn/tensorflow/test/test_onegraph_estimator.py @@ -5,7 +5,7 @@ import tensorflow as tf from bob.learn.tensorflow.network import dummy -from bob.learn.tensorflow.trainers import LogitsTrainer, LogitsCenterLossTrainer +from bob.learn.tensorflow.estimators import Logits, LogitsCenterLoss from bob.learn.tensorflow.dataset.tfrecords import shuffle_data_and_labels, batch_data_and_labels, shuffle_data_and_labels_image_augmentation @@ -39,7 +39,7 @@ def test_logitstrainer(): # Trainer logits try: embedding_validation = False - trainer = LogitsTrainer(model_dir=model_dir, + trainer = Logits(model_dir=model_dir, architecture=dummy, optimizer=tf.train.GradientDescentOptimizer(learning_rate), n_classes=10, @@ -59,7 +59,7 @@ def test_logitstrainer(): def test_logitstrainer_embedding(): try: embedding_validation = True - trainer = LogitsTrainer(model_dir=model_dir, + trainer = Logits(model_dir=model_dir, architecture=dummy, optimizer=tf.train.GradientDescentOptimizer(learning_rate), n_classes=10, @@ -82,7 +82,7 @@ def test_logitstrainer_centerloss(): embedding_validation = False run_config = tf.estimator.RunConfig() run_config = run_config.replace(save_checkpoints_steps=1000) - trainer = LogitsCenterLossTrainer( + trainer = LogitsCenterLoss( model_dir=model_dir, architecture=dummy, optimizer=tf.train.GradientDescentOptimizer(learning_rate), @@ -115,7 +115,7 @@ def test_logitstrainer_centerloss(): def test_logitstrainer_centerloss_embedding(): try: embedding_validation = True - trainer = LogitsCenterLossTrainer( + trainer = LogitsCenterLoss( model_dir=model_dir, architecture=dummy, optimizer=tf.train.GradientDescentOptimizer(learning_rate), diff --git a/bob/learn/tensorflow/trainers/__init__.py b/bob/learn/tensorflow/trainers/__init__.py index 7812e271..ee0a9819 100755 --- a/bob/learn/tensorflow/trainers/__init__.py +++ b/bob/learn/tensorflow/trainers/__init__.py @@ -3,7 +3,6 @@ from .Trainer import Trainer from .SiameseTrainer import SiameseTrainer from .TripletTrainer import TripletTrainer from .learning_rate import exponential_decay, constant -from .LogitsTrainer import LogitsTrainer, LogitsCenterLossTrainer import numpy -- GitLab