Commit d46cb5c1 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Allow architectures to provide the logits layer

parent 68d6ca2f
...@@ -25,6 +25,7 @@ class Regressor(estimator.Estimator): ...@@ -25,6 +25,7 @@ class Regressor(estimator.Estimator):
add_histograms=None, add_histograms=None,
optimize_loss=tf.contrib.layers.optimize_loss, optimize_loss=tf.contrib.layers.optimize_loss,
optimize_loss_learning_rate=None, optimize_loss_learning_rate=None,
architecture_has_logits=False,
): ):
self.architecture = architecture self.architecture = architecture
self.label_dimension = label_dimension self.label_dimension = label_dimension
...@@ -51,12 +52,15 @@ class Regressor(estimator.Estimator): ...@@ -51,12 +52,15 @@ class Regressor(estimator.Estimator):
# Checking if we have some variables/scope that we may want to shut # Checking if we have some variables/scope that we may want to shut
# down # down
trainable_variables = get_trainable_variables(extra_checkpoint, mode=mode) trainable_variables = get_trainable_variables(extra_checkpoint, mode=mode)
prelogits = self.architecture( prelogits, end_points = self.architecture(
data, mode=mode, trainable_variables=trainable_variables data, mode=mode, trainable_variables=trainable_variables
)[0]
logits = append_logits(
prelogits, label_dimension, trainable_variables=trainable_variables
) )
if architecture_has_logits:
logits, prelogits = prelogits, end_points["prelogits"]
else:
logits = append_logits(
prelogits, label_dimension, trainable_variables=trainable_variables
)
predictions = {"predictions": logits, "key": key} predictions = {"predictions": logits, "key": key}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment