Keras gotchas
Using Keras with estimators (without going through `tf.keras.estimator.model_to_estimator`) is really weird. I am opening an issue here to keep track of the gotchas.
The migration guide (https://www.tensorflow.org/beta/guide/migration_guide#using_a_custom_model_fn) explains what you should do, but it does not cover everything.
- Keras variables do not go to variable stores. To use `tf.train.init_from_checkpoint` with Keras variables, you need to pass the list of variables explicitly. Something like this:

  ```python
  assignment_map = {v.name.split(":")[0]: v for v in model.variables}
  tf.train.init_from_checkpoint(
      ckpt_dir_or_file=model_folder,
      assignment_map=assignment_map,
  )
  ```
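For intuition about the `assignment_map` keys: Keras variable names carry a `:0` output-index suffix that checkpoint keys do not have, which is why the snippet splits on `":"`. A dependency-free illustration (the variable names here are hypothetical examples, not taken from a real model):

```python
# Hypothetical Keras-style variable names; in practice they come from model.variables.
variable_names = ["dense/kernel:0", "dense/bias:0", "batch_normalization/gamma:0"]

# Same key transformation as in the assignment_map above:
# strip the ":<output index>" suffix to get the checkpoint key.
checkpoint_keys = [name.split(":")[0] for name in variable_names]
print(checkpoint_keys)  # -> ['dense/kernel', 'dense/bias', 'batch_normalization/gamma']
```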
- Keras layers (especially batch norm) do not update the `tf.GraphKeys.UPDATE_OPS` collection, so you have to add those ops manually:

  ```python
  # Add batch norm updates to the graph
  for update_op in model.get_updates_for(inputs) + model.get_updates_for(None):
      tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op)
  ```
- Keras layers' variables go to the global trainable variables collection (weird enough, because you cannot use `init_from_checkpoint` on them). Doing something like:

  ```python
  for layer in model.layers:
      layer.trainable = False
  ```

  does not remove them from that collection. So to use `tf.contrib.layers.optimize_loss` with Keras layers, you have to pass the variables explicitly:

  ```python
  tf.contrib.layers.optimize_loss(
      ...,
      variables=model.trainable_variables,
  )
  ```
- In Keras models, `model.variables` and `model.trainable_variables` are different. So you would handle L2 loss like this:

  ```python
  # Add L2 losses to the graph
  regularization_loss = 0.0
  l2 = tf.keras.regularizers.l2(weight_decay)
  for variable in model.trainable_variables:
      regularization_loss += l2(variable)
  tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, regularization_loss)
  ```

  To also pick up losses that the layers themselves registered (e.g. via `kernel_regularizer`):

  ```python
  # Get both the unconditional losses (the None part)
  # and the input-conditional losses (the features part).
  reg_losses = model.get_losses_for(None) + model.get_losses_for(features)
  ```
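For intuition about what `tf.keras.regularizers.l2(weight_decay)` contributes per variable (it computes `weight_decay * sum(square(x))`), here is a plain-Python sketch of the accumulation loop, with made-up weights and an arbitrary example `weight_decay`:

```python
# Plain-Python sketch of the L2 accumulation above; the lists stand in for
# (flattened) variable tensors, weight_decay is an arbitrary example value.
weight_decay = 1e-4

def l2_penalty(values, l2=weight_decay):
    # Same formula as tf.keras.regularizers.l2: l2 * sum of squared entries.
    return l2 * sum(v * v for v in values)

trainable_variables = [
    [1.0] * 6,        # e.g. a 2x3 kernel, flattened
    [2.0, 2.0, 2.0],  # e.g. a bias vector
]
regularization_loss = sum(l2_penalty(v) for v in trainable_variables)
# regularization_loss == weight_decay * (6 + 12) == 0.0018
```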
- You have to name every layer/model explicitly, otherwise the auto-generated names (`dense`, `dense_1`, ...) depend on how many Keras layers were created before, so variable names change between runs and checkpoints stop matching.
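To see why the naming point matters, here is a small stand-alone model of Keras-style name uniquification. This is not actual Keras code, just an illustration of a global per-class counter:

```python
import itertools
from collections import defaultdict

# Illustration only: Keras-style auto-naming keeps a global per-class counter,
# so the name a layer gets depends on every layer created before it.
_name_counters = defaultdict(itertools.count)

def auto_name(class_name):
    n = next(_name_counters[class_name])
    return class_name if n == 0 else f"{class_name}_{n}"

# Building "the same" two-dense model twice in one process:
first_model = [auto_name("dense"), auto_name("dense")]   # ['dense', 'dense_1']
second_model = [auto_name("dense"), auto_name("dense")]  # ['dense_2', 'dense_3']
# Variable names differ between the two builds, so a checkpoint saved from
# one cannot be restored into the other without explicit layer names.
```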