diff --git a/bob/learn/tensorflow/trainers/LogitsTrainer.py b/bob/learn/tensorflow/estimators/Logits.py
similarity index 95%
rename from bob/learn/tensorflow/trainers/LogitsTrainer.py
rename to bob/learn/tensorflow/estimators/Logits.py
index a7d12b0145add029fdd2836fd98998a8cbea784a..5c582f82d57def58da29ae6a2300b362d1d78b70 100755
--- a/bob/learn/tensorflow/trainers/LogitsTrainer.py
+++ b/bob/learn/tensorflow/estimators/Logits.py
@@ -22,7 +22,7 @@ import logging
 logger = logging.getLogger("bob.learn")
 
 
-class LogitsTrainer(estimator.Estimator):
+class Logits(estimator.Estimator):
     """
     NN Trainer whose with logits as last layer
 
@@ -149,12 +149,12 @@ class LogitsTrainer(estimator.Estimator):
             return tf.estimator.EstimatorSpec(
                 mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops)
 
-        super(LogitsTrainer, self).__init__(model_fn=_model_fn,
-                                            model_dir=model_dir,
-                                            config=config)
+        super(Logits, self).__init__(model_fn=_model_fn,
+                                     model_dir=model_dir,
+                                     config=config)
 
 
-class LogitsCenterLossTrainer(estimator.Estimator):
+class LogitsCenterLoss(estimator.Estimator):
     """
     NN Trainer whose with logits as last layer
 
@@ -281,6 +281,6 @@ class LogitsCenterLossTrainer(estimator.Estimator):
             return tf.estimator.EstimatorSpec(
                 mode=mode, loss=self.loss, eval_metric_ops=eval_metric_ops)
 
-        super(LogitsCenterLossTrainer, self).__init__(model_fn=_model_fn,
-                                                      model_dir=model_dir,
-                                                      config=config)
+        super(LogitsCenterLoss, self).__init__(model_fn=_model_fn,
+                                               model_dir=model_dir,
+                                               config=config)
diff --git a/bob/learn/tensorflow/estimators/__init__.py b/bob/learn/tensorflow/estimators/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..04ce0a5e06b13e1b8cc7cb52b3f14632f4185ac4
--- /dev/null
+++ b/bob/learn/tensorflow/estimators/__init__.py
@@ -0,0 +1,23 @@
+from .Logits import Logits, LogitsCenterLoss
+
+# gets sphinx autodoc done right - don't remove it
+def __appropriate__(*args):
+    """Says object was actually declared here, and not on the import module.
+
+    Parameters:
+
+      *args: An iterable of objects to modify
+
+    Resolves `Sphinx referencing issues
+    <https://github.com/sphinx-doc/sphinx/issues/3048>`
+    """
+
+    for obj in args: obj.__module__ = __name__
+
+__appropriate__(
+    Logits,
+    LogitsCenterLoss
+    )
+__all__ = [_ for _ in dir() if not _.startswith('_')]
+
+
diff --git a/bob/learn/tensorflow/test/test_image_dataset.py b/bob/learn/tensorflow/test/test_image_dataset.py
index 5e0e1b1960de4f4337372cc3ae6f2afb3f481d6a..9fdb9827012308abd0ef830dec65d74752b7265f 100755
--- a/bob/learn/tensorflow/test/test_image_dataset.py
+++ b/bob/learn/tensorflow/test/test_image_dataset.py
@@ -52,6 +52,7 @@ def test_logitstrainer_images():
         shutil.rmtree(model_dir, ignore_errors=True)
     except Exception:
         pass
+
 
 def run_logitstrainer_images(trainer):
     # Cleaning up
@@ -93,4 +94,3 @@ def run_logitstrainer_images(trainer):
     tf.reset_default_graph()
     assert len(tf.global_variables()) == 0
 
-
diff --git a/bob/learn/tensorflow/test/test_onegraph_estimator.py b/bob/learn/tensorflow/test/test_onegraph_estimator.py
index 33cadd805f3b55f9e8866302f74d561913e1b465..9a252d8d7fd15ac6fbf1dad31f527dc7076a4478 100755
--- a/bob/learn/tensorflow/test/test_onegraph_estimator.py
+++ b/bob/learn/tensorflow/test/test_onegraph_estimator.py
@@ -5,7 +5,7 @@ import tensorflow as tf
 
 from bob.learn.tensorflow.network import dummy
-from bob.learn.tensorflow.trainers import LogitsTrainer, LogitsCenterLossTrainer
+from bob.learn.tensorflow.estimators import Logits, LogitsCenterLoss
 from bob.learn.tensorflow.dataset.tfrecords import shuffle_data_and_labels, batch_data_and_labels, shuffle_data_and_labels_image_augmentation
 
 
 
@@ -39,7 +39,7 @@ def test_logitstrainer():
     # Trainer logits
     try:
         embedding_validation = False
-        trainer = LogitsTrainer(model_dir=model_dir,
+        trainer = Logits(model_dir=model_dir,
                                 architecture=dummy,
                                 optimizer=tf.train.GradientDescentOptimizer(learning_rate),
                                 n_classes=10,
@@ -59,7 +59,7 @@ def test_logitstrainer():
 def test_logitstrainer_embedding():
     try:
         embedding_validation = True
-        trainer = LogitsTrainer(model_dir=model_dir,
+        trainer = Logits(model_dir=model_dir,
                                 architecture=dummy,
                                 optimizer=tf.train.GradientDescentOptimizer(learning_rate),
                                 n_classes=10,
@@ -82,7 +82,7 @@ def test_logitstrainer_centerloss():
         embedding_validation = False
         run_config = tf.estimator.RunConfig()
         run_config = run_config.replace(save_checkpoints_steps=1000)
-        trainer = LogitsCenterLossTrainer(
+        trainer = LogitsCenterLoss(
             model_dir=model_dir,
             architecture=dummy,
             optimizer=tf.train.GradientDescentOptimizer(learning_rate),
@@ -115,7 +115,7 @@ def test_logitstrainer_centerloss():
 def test_logitstrainer_centerloss_embedding():
     try:
         embedding_validation = True
-        trainer = LogitsCenterLossTrainer(
+        trainer = LogitsCenterLoss(
             model_dir=model_dir,
             architecture=dummy,
             optimizer=tf.train.GradientDescentOptimizer(learning_rate),
diff --git a/bob/learn/tensorflow/trainers/__init__.py b/bob/learn/tensorflow/trainers/__init__.py
index 7812e271417bf040e33716472f564ee9f187b349..ee0a98197281f67b7defe48a59d91ff672bae6fa 100755
--- a/bob/learn/tensorflow/trainers/__init__.py
+++ b/bob/learn/tensorflow/trainers/__init__.py
@@ -3,7 +3,6 @@ from .Trainer import Trainer
 from .SiameseTrainer import SiameseTrainer
 from .TripletTrainer import TripletTrainer
 from .learning_rate import exponential_decay, constant
-from .LogitsTrainer import LogitsTrainer, LogitsCenterLossTrainer
 
 
 import numpy
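Migration note: downstream code only needs to switch the import path and the class names; the constructor keywords visible in the updated tests (model_dir, architecture, optimizer, n_classes) are untouched by this patch. A minimal sketch, assuming only what the hunks above show; the commented construction call is elided exactly as it is in the test hunks:

# Old import (removed by this patch):
#   from bob.learn.tensorflow.trainers import LogitsTrainer, LogitsCenterLossTrainer
# New import after the move to the estimators package:
from bob.learn.tensorflow.estimators import Logits, LogitsCenterLoss

# Construction mirrors the updated tests; arguments beyond those visible
# in this diff are omitted here:
#   trainer = Logits(model_dir=model_dir,
#                    architecture=dummy,
#                    optimizer=tf.train.GradientDescentOptimizer(learning_rate),
#                    n_classes=10, ...)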