Keras gotchas

Using Keras models inside a custom tf.estimator model_fn (without going through tf.keras.estimator.model_to_estimator) has several surprising pitfalls. I am opening this issue to keep track of the gotchas.

This guide explains what you should do: https://www.tensorflow.org/beta/guide/migration_guide#using_a_custom_model_fn — but it does not cover everything.

  • Keras variables are not registered in TensorFlow's variable store. To use tf.train.init_from_checkpoint with Keras variables, you need to pass an explicit assignment map to the function. Something like this:
      assignment_map = {v.name.split(":")[0]: v for v in model.variables}
      tf.train.init_from_checkpoint(
          ckpt_dir_or_file=model_folder, assignment_map=assignment_map
      )
    
  • Keras layers (notably batch normalization) do not add their update ops to the tf.GraphKeys.UPDATE_OPS collection. Hence you have to add those manually:
      # Add batch norm updates to the graph
      for update_op in model.get_updates_for(inputs) + model.get_updates_for(None):
          tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op)
  • Keras layers' variables do end up in the global trainable-variables collection (oddly, given that init_from_checkpoint still cannot see them). Setting:
    for layer in model.layers:
        layer.trainable = False
    will not remove them from that collection. To use tf.contrib.layers.optimize_loss with Keras layers, you have to pass the trainable variables explicitly:
    tf.contrib.layers.optimize_loss(
        ...
        variables=model.trainable_variables
    )
    Otherwise, all layers will be trained, including the ones you froze.
  • In Keras models, model.variables and model.trainable_variables are different lists: the former also includes non-trainable variables such as batch-norm moving statistics. So you would handle the L2 loss like this:
      # Add L2 losses to the graph
      regularization_loss = 0.0
      l2 = tf.keras.regularizers.l2(weight_decay)
      for variable in model.trainable_variables:
          regularization_loss += l2(variable)
      tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, regularization_loss)
    or you retrieve the regularization losses that Keras already tracks:
      # Get both the unconditional losses (the None part)
      # and the input-conditional losses (the features part).
      reg_losses = model.get_losses_for(None) + model.get_losses_for(features)
  • You have to name every layer and model explicitly; otherwise Keras auto-generates counter-based names (dense, dense_1, ...), so the variable names you end up with depend on how many Keras layers were created earlier in the process.
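On the UPDATE_OPS bullet: adding the update ops to the collection is only half the story — the train op must also be made to depend on them, or the moving averages never run. A minimal runnable sketch (a plain assign op stands in for the Keras batch-norm update; a real model's updates would come from model.get_updates_for(...) as shown above):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

x = tf.placeholder(tf.float32, shape=(None, 4))
w = tf.Variable(tf.ones((4, 1)))
loss = tf.reduce_mean(tf.matmul(x, w))

# Stand-in for a batch-norm moving-average update op.
moving_mean = tf.Variable(0.0, trainable=False)
update = tf.assign(moving_mean, 0.9 * moving_mean + 0.1 * tf.reduce_mean(x))
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update)

# Without this control dependency the collected updates never execute.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op, feed_dict={x: [[1.0, 1.0, 1.0, 1.0]]})
    mean_after = sess.run(moving_mean)  # 0.9 * 0.0 + 0.1 * 1.0 = 0.1
```

Running train_op alone would silently skip the update if the control dependency were dropped.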
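On the L2 bullet: the arithmetic of tf.keras.regularizers.l2 is just weight_decay * sum(v ** 2) per variable. A sketch of building the penalty and folding it into the total loss via the collection (the variables, their values, and weight_decay are illustrative stand-ins for model.trainable_variables):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

weight_decay = 0.5  # illustrative value

# Stand-ins for model.trainable_variables.
kernel = tf.Variable([[1.0, 2.0]], name="demo/kernel")
bias = tf.Variable([0.0], name="demo/bias")
trainable_variables = [kernel, bias]

# tf.keras.regularizers.l2(wd) computes wd * sum(v ** 2);
# tf.nn.l2_loss computes sum(v ** 2) / 2, hence the factor of 2.
regularization_loss = tf.add_n(
    [2.0 * weight_decay * tf.nn.l2_loss(v) for v in trainable_variables]
)
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, regularization_loss)

# Fold the collected penalties into the total loss.
model_loss = tf.constant(1.0)  # stand-in for the data loss
total_loss = model_loss + tf.add_n(
    tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    reg_value, total_value = sess.run([regularization_loss, total_loss])
    # reg_value = 0.5 * (1 + 4 + 0) = 2.5; total_value = 3.5
```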
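On the naming bullet, the problem is easy to see directly: unnamed layers get counter-based auto names that shift with how many layers were built earlier in the process, while an explicit name stays fixed (the layer sizes here are arbitrary):

```python
import tensorflow as tf

# Auto-generated names use a global counter ("dense", "dense_1", ...),
# so they depend on how many layers were created before this point.
first = tf.keras.layers.Dense(1)
second = tf.keras.layers.Dense(1)

# An explicit name is stable regardless of construction order.
logits = tf.keras.layers.Dense(1, name="logits")

print(first.name, second.name, logits.name)
```

Since checkpoint keys are derived from these names, relying on auto names makes checkpoints fragile.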
Edited by Amir MOHAMMADI