Skip to content
Snippets Groups Projects
Commit d41b9798 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Allow architectures to provide the logits layer

parent 66ae9cad
No related branches found
No related tags found
No related merge requests found
...@@ -25,6 +25,7 @@ class Regressor(estimator.Estimator): ...@@ -25,6 +25,7 @@ class Regressor(estimator.Estimator):
add_histograms=None, add_histograms=None,
optimize_loss=tf.contrib.layers.optimize_loss, optimize_loss=tf.contrib.layers.optimize_loss,
optimize_loss_learning_rate=None, optimize_loss_learning_rate=None,
architecture_has_logits=False,
): ):
self.architecture = architecture self.architecture = architecture
self.label_dimension = label_dimension self.label_dimension = label_dimension
...@@ -51,12 +52,15 @@ class Regressor(estimator.Estimator): ...@@ -51,12 +52,15 @@ class Regressor(estimator.Estimator):
# Checking if we have some variables/scope that we may want to shut # Checking if we have some variables/scope that we may want to shut
# down # down
trainable_variables = get_trainable_variables(extra_checkpoint, mode=mode) trainable_variables = get_trainable_variables(extra_checkpoint, mode=mode)
prelogits = self.architecture( prelogits, end_points = self.architecture(
data, mode=mode, trainable_variables=trainable_variables data, mode=mode, trainable_variables=trainable_variables
)[0]
logits = append_logits(
prelogits, label_dimension, trainable_variables=trainable_variables
) )
if architecture_has_logits:
logits, prelogits = prelogits, end_points["prelogits"]
else:
logits = append_logits(
prelogits, label_dimension, trainable_variables=trainable_variables
)
predictions = {"predictions": logits, "key": key} predictions = {"predictions": logits, "key": key}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment