Commit 638290a3 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

New rule to shutdown gradient variables

parent e12fdb85
...@@ -22,7 +22,7 @@ def append_logits(graph, ...@@ -22,7 +22,7 @@ def append_logits(graph,
reuse=reuse) reuse=reuse)
def is_trainable(name, trainable_variables): def is_trainable(name, trainable_variables, mode=tf.estimator.ModeKeys.TRAIN):
""" """
Check if a variable is trainable or not Check if a variable is trainable or not
...@@ -37,9 +37,14 @@ def is_trainable(name, trainable_variables): ...@@ -37,9 +37,14 @@ def is_trainable(name, trainable_variables):
If None, the variable/scope is trained If None, the variable/scope is trained
""" """
# if mode is not training, so we shutdown
if mode != tf.estimator.ModeKeys.TRAIN:
return False
# If None, we train by default # If None, we train by default
if trainable_variables is None: if trainable_variables is None:
return True return True
# Here is my choice to shutdown the whole scope # Here is my choice to shutdown the whole scope
return name in trainable_variables return name in trainable_variables
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment