Commit d46cb5c1 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Allow architectures to provide the logits layer

parent 68d6ca2f
......@@ -25,6 +25,7 @@ class Regressor(estimator.Estimator):
add_histograms=None,
optimize_loss=tf.contrib.layers.optimize_loss,
optimize_loss_learning_rate=None,
architecture_has_logits=False,
):
self.architecture = architecture
self.label_dimension = label_dimension
......@@ -51,9 +52,12 @@ class Regressor(estimator.Estimator):
# Checking if we have some variables/scope that we may want to shut
# down
trainable_variables = get_trainable_variables(extra_checkpoint, mode=mode)
prelogits = self.architecture(
prelogits, end_points = self.architecture(
data, mode=mode, trainable_variables=trainable_variables
)[0]
)
if architecture_has_logits:
logits, prelogits = prelogits, end_points["prelogits"]
else:
logits = append_logits(
prelogits, label_dimension, trainable_variables=trainable_variables
)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment