Skip to content
Snippets Groups Projects
Commit c82154ef authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Allow architectures to provide the logits layer

parent e7445a23
No related branches found
No related tags found
No related merge requests found
......@@ -25,6 +25,7 @@ class Regressor(estimator.Estimator):
add_histograms=None,
optimize_loss=tf.contrib.layers.optimize_loss,
optimize_loss_learning_rate=None,
architecture_has_logits=False,
):
self.architecture = architecture
self.label_dimension = label_dimension
......@@ -51,12 +52,15 @@ class Regressor(estimator.Estimator):
# Checking if we have some variables/scope that we may want to shut
# down
trainable_variables = get_trainable_variables(extra_checkpoint, mode=mode)
prelogits = self.architecture(
prelogits, end_points = self.architecture(
data, mode=mode, trainable_variables=trainable_variables
)[0]
logits = append_logits(
prelogits, label_dimension, trainable_variables=trainable_variables
)
if architecture_has_logits:
logits, prelogits = prelogits, end_points["prelogits"]
else:
logits = append_logits(
prelogits, label_dimension, trainable_variables=trainable_variables
)
predictions = {"predictions": logits, "key": key}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment