Commit 1d0cca35 authored by Amir MOHAMMADI
Browse files

bug fixes in estimators

parent 88eea1d2
Pipeline #30335 canceled with stage
in 29 minutes and 29 seconds
......@@ -121,8 +121,9 @@ class Logits(estimator.Estimator):
self.loss = None
self.embedding_validation = embedding_validation
self.extra_checkpoint = extra_checkpoint
self.apply_moving_averages = apply_moving_averages
if apply_moving_averages and isinstance(optimizer, tf.train.Optimizer):
if self.apply_moving_averages and isinstance(optimizer, tf.train.Optimizer):
logger.info("Encapsulating the optimizer with the MovingAverageOptimizer")
optimizer = tf.contrib.opt.MovingAverageOptimizer(optimizer)
......@@ -284,7 +285,7 @@ class Logits(estimator.Estimator):
)
# Get the moving average saver after optimizer.minimize is called
if apply_moving_averages:
if self.apply_moving_averages:
self.saver, self.scaffold = moving_average_scaffold(
self.optimizer.optimizer
if hasattr(self.optimizer, "optimizer")
......@@ -354,7 +355,6 @@ class LogitsCenterLoss(estimator.Estimator):
):
self.architecture = architecture
self.optimizer = optimizer
self.n_classes = n_classes
self.alpha = alpha
self.factor = factor
......@@ -368,7 +368,7 @@ class LogitsCenterLoss(estimator.Estimator):
if self.architecture is None:
raise ValueError("Please specify a function to build the architecture !!")
if self.optimizer is None:
if optimizer is None:
raise ValueError(
"Please specify a optimizer (https://www.tensorflow.org/"
"api_guides/python/train) !!"
......@@ -380,6 +380,7 @@ class LogitsCenterLoss(estimator.Estimator):
if self.apply_moving_averages and isinstance(optimizer, tf.train.Optimizer):
logger.info("Encapsulating the optimizer with the MovingAverageOptimizer")
optimizer = tf.contrib.opt.MovingAverageOptimizer(optimizer)
self.optimizer = optimizer
def _model_fn(features, labels, mode, config):
......
......@@ -116,7 +116,7 @@ class Siamese(estimator.Estimator):
loss=self.loss,
global_step=tf.train.get_or_create_global_step(),
optimizer=self.optimizer,
learning_rate=self.learning_rate,
learning_rate=self.optimize_loss_learning_rate,
)
# add histograms summaries
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment