From 3806dec45b68a46c7c2ce96758a271a2862b59aa Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Thu, 19 Oct 2017 16:26:41 +0200
Subject: [PATCH] Moved the estimator classes from the trainers module to the
 estimators module

Renamed the logits trainers: LogitsTrainer is now Logits and
LogitsCenterLossTrainer is now LogitsCenterLoss.
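A minimal usage sketch under the new layout (only the constructor
arguments visible in the updated tests are shown; `dummy`, the model
directory and the learning rate are placeholders, and any remaining
constructor arguments are omitted):

    import tensorflow as tf

    from bob.learn.tensorflow.network import dummy
    # previously: from bob.learn.tensorflow.trainers import LogitsTrainer
    from bob.learn.tensorflow.estimators import Logits

    # Build the renamed estimator; further arguments (loss, input
    # functions, etc.) are left out of this sketch.
    trainer = Logits(model_dir="/tmp/logits_model",
                     architecture=dummy,
                     optimizer=tf.train.GradientDescentOptimizer(0.1),
                     n_classes=10)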
---
 .../LogitsTrainer.py => estimators/Logits.py} | 16 ++++++-------
 bob/learn/tensorflow/estimators/__init__.py   | 23 +++++++++++++++++++
 .../tensorflow/test/test_image_dataset.py     |  2 +-
 .../test/test_onegraph_estimator.py           | 10 ++++----
 bob/learn/tensorflow/trainers/__init__.py     |  1 -
 5 files changed, 37 insertions(+), 15 deletions(-)
 rename bob/learn/tensorflow/{trainers/LogitsTrainer.py => estimators/Logits.py} (95%)
 create mode 100755 bob/learn/tensorflow/estimators/__init__.py

diff --git a/bob/learn/tensorflow/trainers/LogitsTrainer.py b/bob/learn/tensorflow/estimators/Logits.py
similarity index 95%
rename from bob/learn/tensorflow/trainers/LogitsTrainer.py
rename to bob/learn/tensorflow/estimators/Logits.py
index a7d12b01..5c582f82 100755
--- a/bob/learn/tensorflow/trainers/LogitsTrainer.py
+++ b/bob/learn/tensorflow/estimators/Logits.py
@@ -22,7 +22,7 @@ import logging
 logger = logging.getLogger("bob.learn")
 
 
-class LogitsTrainer(estimator.Estimator):
+class Logits(estimator.Estimator):
     """
     NN Trainer with logits as the last layer
 
@@ -149,12 +149,12 @@ class LogitsTrainer(estimator.Estimator):
                 return tf.estimator.EstimatorSpec(
                     mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops)
 
-        super(LogitsTrainer, self).__init__(model_fn=_model_fn,
-                                            model_dir=model_dir,
-                                            config=config)
+        super(Logits, self).__init__(model_fn=_model_fn,
+                                     model_dir=model_dir,
+                                     config=config)
 
 
-class LogitsCenterLossTrainer(estimator.Estimator):
+class LogitsCenterLoss(estimator.Estimator):
     """
     NN Trainer with logits as the last layer
 
@@ -281,6 +281,6 @@ class LogitsCenterLossTrainer(estimator.Estimator):
                 return tf.estimator.EstimatorSpec(
                     mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops)
 
-        super(LogitsCenterLossTrainer, self).__init__(model_fn=_model_fn,
-                                                      model_dir=model_dir,
-                                                      config=config)
+        super(LogitsCenterLoss, self).__init__(model_fn=_model_fn,
+                                               model_dir=model_dir,
+                                               config=config)
diff --git a/bob/learn/tensorflow/estimators/__init__.py b/bob/learn/tensorflow/estimators/__init__.py
new file mode 100755
index 00000000..04ce0a5e
--- /dev/null
+++ b/bob/learn/tensorflow/estimators/__init__.py
@@ -0,0 +1,23 @@
+from .Logits import Logits, LogitsCenterLoss
+
+# gets sphinx autodoc done right - don't remove it
+def __appropriate__(*args):
+  """Says object was actually declared here, an not on the import module.
+
+  Parameters:
+
+    *args: An iterable of objects to modify
+
+  Resolves `Sphinx referencing issues
+  <https://github.com/sphinx-doc/sphinx/issues/3048>`_
+  """
+
+  for obj in args: obj.__module__ = __name__
+
+__appropriate__(
+    Logits,
+    LogitsCenterLoss
+    )
+__all__ = [_ for _ in dir() if not _.startswith('_')]
+
+
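The __appropriate__ helper above only rebinds __module__, so that Sphinx
autodoc attributes the classes to the package rather than to the file
that defines them. A small sketch of the effect, assuming the package is
importable:

    from bob.learn.tensorflow import estimators

    # Without __appropriate__ this attribute would read
    # "bob.learn.tensorflow.estimators.Logits" (the defining module);
    # after the rebind it is the package itself.
    assert estimators.Logits.__module__ == "bob.learn.tensorflow.estimators"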
diff --git a/bob/learn/tensorflow/test/test_image_dataset.py b/bob/learn/tensorflow/test/test_image_dataset.py
index 5e0e1b19..9fdb9827 100755
--- a/bob/learn/tensorflow/test/test_image_dataset.py
+++ b/bob/learn/tensorflow/test/test_image_dataset.py
@@ -52,6 +52,7 @@ def test_logitstrainer_images():
             shutil.rmtree(model_dir, ignore_errors=True)
         except Exception:
             pass        
+    
 
 def run_logitstrainer_images(trainer):
     # Cleaning up
@@ -93,4 +94,3 @@ def run_logitstrainer_images(trainer):
     tf.reset_default_graph()
     assert len(tf.global_variables()) == 0
 
-
diff --git a/bob/learn/tensorflow/test/test_onegraph_estimator.py b/bob/learn/tensorflow/test/test_onegraph_estimator.py
index 33cadd80..9a252d8d 100755
--- a/bob/learn/tensorflow/test/test_onegraph_estimator.py
+++ b/bob/learn/tensorflow/test/test_onegraph_estimator.py
@@ -5,7 +5,7 @@
 import tensorflow as tf
 
 from bob.learn.tensorflow.network import dummy
-from bob.learn.tensorflow.trainers import LogitsTrainer, LogitsCenterLossTrainer
+from bob.learn.tensorflow.estimators import Logits, LogitsCenterLoss
 
 from bob.learn.tensorflow.dataset.tfrecords import shuffle_data_and_labels, batch_data_and_labels, shuffle_data_and_labels_image_augmentation
 
@@ -39,7 +39,7 @@ def test_logitstrainer():
     # Trainer logits
     try:
         embedding_validation = False
-        trainer = LogitsTrainer(model_dir=model_dir,
+        trainer = Logits(model_dir=model_dir,
                                 architecture=dummy,
                                 optimizer=tf.train.GradientDescentOptimizer(learning_rate),
                                 n_classes=10,
@@ -59,7 +59,7 @@ def test_logitstrainer():
 def test_logitstrainer_embedding():
     try:
         embedding_validation = True
-        trainer = LogitsTrainer(model_dir=model_dir,
+        trainer = Logits(model_dir=model_dir,
                                 architecture=dummy,
                                 optimizer=tf.train.GradientDescentOptimizer(learning_rate),
                                 n_classes=10,
@@ -82,7 +82,7 @@ def test_logitstrainer_centerloss():
         embedding_validation = False
         run_config = tf.estimator.RunConfig()
         run_config = run_config.replace(save_checkpoints_steps=1000)
-        trainer = LogitsCenterLossTrainer(
+        trainer = LogitsCenterLoss(
                                 model_dir=model_dir,
                                 architecture=dummy,
                                 optimizer=tf.train.GradientDescentOptimizer(learning_rate),
@@ -115,7 +115,7 @@ def test_logitstrainer_centerloss():
 def test_logitstrainer_centerloss_embedding():
     try:
         embedding_validation = True
-        trainer = LogitsCenterLossTrainer(
+        trainer = LogitsCenterLoss(
                                 model_dir=model_dir,
                                 architecture=dummy,
                                 optimizer=tf.train.GradientDescentOptimizer(learning_rate),
diff --git a/bob/learn/tensorflow/trainers/__init__.py b/bob/learn/tensorflow/trainers/__init__.py
index 7812e271..ee0a9819 100755
--- a/bob/learn/tensorflow/trainers/__init__.py
+++ b/bob/learn/tensorflow/trainers/__init__.py
@@ -3,7 +3,6 @@ from .Trainer import Trainer
 from .SiameseTrainer import SiameseTrainer
 from .TripletTrainer import TripletTrainer
 from .learning_rate import exponential_decay, constant
-from .LogitsTrainer import LogitsTrainer, LogitsCenterLossTrainer
 import numpy
 
 
-- 
GitLab