diff --git a/cnn_training/arcface.py b/cnn_training/arcface.py index 611837038dcce7fe3c1ee78968753e9e45e03dac..712cf6dca83019dae9adea274a5622b491a295b9 100644 --- a/cnn_training/arcface.py +++ b/cnn_training/arcface.py @@ -173,12 +173,19 @@ def create_model( ) embeddings = pre_model.get_layer("embeddings").output + labels = tf.keras.layers.Input([], name="label") if pre_train: - pre_model = add_top(pre_model, n_classes=n_classes) + # pre_model = add_top(pre_model, n_classes=n_classes) + logits_premodel = ArcFaceLayer( + n_classes, + s=model_spec["arcface"]["s"], + m=model_spec["arcface"]["m"], + arc=False, + )(embeddings, None) # Wrapping the embedding validation - logits_premodel = pre_model.get_layer("logits").output + # logits_premodel = pre_model.get_layer("logits").output pre_model = EmbeddingValidation( pre_model.input, outputs=[logits_premodel, embeddings], name=pre_model.name @@ -187,9 +194,11 @@ def create_model( ################################ ## Creating the specific models if "arcface" in model_spec: - labels = tf.keras.layers.Input([], name="label") logits_arcface = ArcFaceLayer( - n_classes, s=model_spec["arcface"]["s"], m=model_spec["arcface"]["m"] + n_classes, + s=model_spec["arcface"]["s"], + m=model_spec["arcface"]["m"], + arc=True, )(embeddings, labels) arc_model = ArcFaceModel( inputs=(pre_model.input, labels), outputs=[logits_arcface, embeddings] @@ -253,7 +262,9 @@ def build_and_compile_models( optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"], ) - arc_model.compile(optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"]) + arc_model.compile( + optimizer=optimizer, loss=cross_entropy, run_eagerly=True, metrics=["accuracy"] + ) return pre_model, arc_model @@ -365,17 +376,20 @@ def train_and_evaluate( callbacks = add_backup_callback(callbacks, backup_dir=f"{checkpoint_path}/backup") # Train the Cross-entropy model if the case + if pre_train: # STEPS_PER_EPOCH pre_model.fit( train_ds, - epochs=2, + epochs=20, validation_data=val_ds, steps_per_epoch=STEPS_PER_EPOCH, validation_steps=VALIDATION_SAMPLES // VALIDATION_BATCH_SIZE, callbacks=callbacks, verbose=2, ) + # Copying the last variable + arc_model.trainable_variables[-1].assign(pre_model.trainable_variables[-1]) # STEPS_PER_EPOCH # epochs=epochs * KERAS_EPOCH_MULTIPLIER,