Keras gotchas
Using Keras with estimators (without using `tf.keras.estimator.model_to_estimator`) is really weird. I am opening this issue to keep track of the gotchas.
Look at this guide: https://www.tensorflow.org/beta/guide/migration_guide#using_a_custom_model_fn, which explains what you should do, but it does not cover everything.
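For orientation, here is a minimal sketch of a custom `model_fn` showing where the workarounds below plug in. This is not from the guide: `build_model` is a hypothetical helper returning a `tf.keras.Model`, and the loss and optimizer are just placeholders.

```python
import tensorflow as tf  # TF 1.x

def model_fn(features, labels, mode):
    model = build_model()  # hypothetical helper returning a tf.keras.Model
    logits = model(features, training=(mode == tf.estimator.ModeKeys.TRAIN))

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions={"logits": logits})

    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    # Collect the model's regularization losses yourself (see below).
    reg_losses = model.get_losses_for(None) + model.get_losses_for(features)
    if reg_losses:
        loss += tf.add_n(reg_losses)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss)

    # Keras does not populate tf.GraphKeys.UPDATE_OPS, so run the model's
    # own update ops (e.g. batch norm moving averages) with the train op.
    update_ops = model.get_updates_for(None) + model.get_updates_for(features)
    optimizer = tf.train.AdamOptimizer()
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(
            loss,
            global_step=tf.train.get_global_step(),
            # Pass the Keras trainable variables explicitly (see below).
            var_list=model.trainable_variables,
        )
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
```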
- Keras variables do not go to variable stores. To use `tf.train.init_from_checkpoint` with Keras variables, one needs to pass the list of variables to the function explicitly. Something like this:

  ```python
  assignment_map = {v.name.split(":")[0]: v for v in model.variables}
  tf.train.init_from_checkpoint(
      ckpt_dir_or_file=model_folder,
      assignment_map=assignment_map,
  )
  ```

- Keras layers (especially batch norm) do not update the `tf.GraphKeys.UPDATE_OPS` collection. Hence you have to add those manually:

  ```python
  # Add batch norm updates to the graph
  for update_op in model.get_updates_for(inputs) + model.get_updates_for(None):
      tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op)
  ```

- Keras layers' variables go to the global trainable variables collection (which is odd, given that you cannot use `init_from_checkpoint` on them). Doing something like:

  ```python
  for layer in model.layers:
      layer.trainable = False
  ```

  will not remove them from that list. To use `tf.contrib.layers.optimize_loss` with Keras layers, you have to pass the variables explicitly:

  ```python
  tf.contrib.layers.optimize_loss(
      ...,
      variables=model.trainable_variables,
  )
  ```

  Otherwise, you will be training all layers.

- In Keras models, `model.variables` and `model.trainable_variables` are different. So you would handle the L2 loss like this:

  ```python
  # Add L2 losses to the graph
  regularization_loss = 0.0
  l2 = tf.keras.regularizers.l2(weight_decay)
  for variable in model.trainable_variables:
      regularization_loss += l2(variable)
  tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, regularization_loss)
  ```

  or you do something like this:

  ```python
  # Get both the unconditional losses (the None part)
  # and the input-conditional losses (the features part).
  reg_losses = model.get_losses_for(None) + model.get_losses_for(features)
  ```

- You have to name every layer/model explicitly; otherwise you end up with different variable names depending on the state of the Keras layers (see the sketch after this list).
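As an illustration of the naming gotcha, here is a small sketch (the layer names shown assume a fresh graph; they shift if other layers of the same type were created first):

```python
import tensorflow as tf  # TF 1.x

x = tf.keras.Input(shape=(5,))

# Auto-generated names depend on how many layers of the same type were
# already created, so variable names are not reproducible across runs:
h = tf.keras.layers.Dense(10)(x)  # variables under "dense/..."
h = tf.keras.layers.Dense(10)(h)  # variables under "dense_1/..."

# An explicit name keeps the variable names stable, which matters for the
# init_from_checkpoint assignment_map above:
logits = tf.keras.layers.Dense(10, name="logits")(h)
```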