diff --git a/cnn_training/arcface.py b/cnn_training/arcface.py
index 611837038dcce7fe3c1ee78968753e9e45e03dac..712cf6dca83019dae9adea274a5622b491a295b9 100644
--- a/cnn_training/arcface.py
+++ b/cnn_training/arcface.py
@@ -173,12 +173,19 @@ def create_model(
     )
 
     embeddings = pre_model.get_layer("embeddings").output
+    labels = tf.keras.layers.Input([], name="label")
 
     if pre_train:
-        pre_model = add_top(pre_model, n_classes=n_classes)
+        # pre_model = add_top(pre_model, n_classes=n_classes)
+        logits_premodel = ArcFaceLayer(
+            n_classes,
+            s=model_spec["arcface"]["s"],
+            m=model_spec["arcface"]["m"],
+            arc=False,
+        )(embeddings, None)
 
         # Wrapping the embedding validation
-        logits_premodel = pre_model.get_layer("logits").output
+        # logits_premodel = pre_model.get_layer("logits").output
 
         pre_model = EmbeddingValidation(
             pre_model.input, outputs=[logits_premodel, embeddings], name=pre_model.name
@@ -187,9 +194,11 @@ def create_model(
     ################################
     ## Creating the specific models
     if "arcface" in model_spec:
-        labels = tf.keras.layers.Input([], name="label")
         logits_arcface = ArcFaceLayer(
-            n_classes, s=model_spec["arcface"]["s"], m=model_spec["arcface"]["m"]
+            n_classes,
+            s=model_spec["arcface"]["s"],
+            m=model_spec["arcface"]["m"],
+            arc=True,
         )(embeddings, labels)
         arc_model = ArcFaceModel(
             inputs=(pre_model.input, labels), outputs=[logits_arcface, embeddings]
@@ -253,7 +262,9 @@ def build_and_compile_models(
             optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"],
         )
 
-    arc_model.compile(optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"])
+    arc_model.compile(
+        optimizer=optimizer, loss=cross_entropy, run_eagerly=True, metrics=["accuracy"]
+    )
 
     return pre_model, arc_model
 
@@ -365,17 +376,20 @@ def train_and_evaluate(
     callbacks = add_backup_callback(callbacks, backup_dir=f"{checkpoint_path}/backup")
 
     # Train the Cross-entropy model if the case
+
     if pre_train:
         # STEPS_PER_EPOCH
         pre_model.fit(
             train_ds,
-            epochs=2,
+            epochs=20,
             validation_data=val_ds,
             steps_per_epoch=STEPS_PER_EPOCH,
             validation_steps=VALIDATION_SAMPLES // VALIDATION_BATCH_SIZE,
             callbacks=callbacks,
             verbose=2,
         )
+        # Copying the last variable
+        arc_model.trainable_variables[-1].assign(pre_model.trainable_variables[-1])
 
     # STEPS_PER_EPOCH
     # epochs=epochs * KERAS_EPOCH_MULTIPLIER,