Skip to content

Created mechanism that allows us to train only parts of the graph during the graph adaptation

Hey,

Sometimes when we want to apply some checkpoint in another dataset (or domain), we may want to choose which parts of the graph we want to re-train.

With this MR, our estimators (Logits.., Triplet, Siamese,) were enhanced with this keyword argument

        extra_checkpoint = {
            "checkpoint_path": <YOUR_CHECKPOINT>,
            "scopes": dict({"<SOURCE_SCOPE>/": "<TARGET_SCOPE>/"}),
            "trainable_variables": [<LIST OF VARIABLES OR SCOPES THAT YOU WANT TO RETRAIN>]
        }

The novelty here is the trainable_variables, where now we can set the parts of the graph we want to do back-propagation. If you set an empty list ( "trainable_variables": []) all variables will not be trainable. If this variable is not set at all, everything is trainable.

This is strongly dependent on how the architecture function is crafted. Look some example on how such functions should be crafted.

Do you have some time to review this one @amohammadi ? Perhaps this can be useful for you.

Thanks

Merge request reports