From 8a322bcc7d464e28061cd5d84f08771c7254b0af Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Sun, 14 Feb 2021 17:58:06 +0100
Subject: [PATCH] Improvements in the arcface trainer

---
 cnn_training/arcface.py | 84 ++++++++++++++++++++++++++++-------------
 1 file changed, 57 insertions(+), 27 deletions(-)

diff --git a/cnn_training/arcface.py b/cnn_training/arcface.py
index fb593509..61183703 100644
--- a/cnn_training/arcface.py
+++ b/cnn_training/arcface.py
@@ -52,11 +52,12 @@ validation-tf-record-path: "/path/lfw_pairs.tfrecord"
 
 
 Usage:
-    arcface.py <config-yaml> <checkpoint_path> 
+    arcface.py <config-yaml> <checkpoint_path> [--pre-train]
     arcface.py -h | --help
 
 Options:
   -h --help             Show this screen.
+  --pre-train           If set pretrains the CNN with the crossentropy softmax for 2 epochs  
   arcface.py arcface -h | help
 
 """
@@ -67,6 +68,7 @@ from functools import partial
 import pkg_resources
 import tensorflow as tf
 from bob.learn.tensorflow.models.inception_resnet_v2 import InceptionResNetV2
+from bob.learn.tensorflow.models import resnet50v1
 from bob.learn.tensorflow.metrics import predict_using_tensors
 from tensorflow.keras import layers
 from bob.learn.tensorflow.callbacks import add_backup_callback
@@ -99,6 +101,9 @@ BACKBONES["inception-resnet-v2"] = InceptionResNetV2
 BACKBONES["efficientnet-B0"] = tf.keras.applications.EfficientNetB0
 BACKBONES["resnet50"] = tf.keras.applications.ResNet50
 BACKBONES["mobilenet-v2"] = tf.keras.applications.MobileNetV2
+# from bob.learn.tensorflow.models.lenet5 import LeNet5_simplified
+
+BACKBONES["resnet50v1"] = resnet50v1
 
 ##############################
 # SOLVER SPECIFICATIONS
@@ -150,7 +155,7 @@ VALIDATION_BATCH_SIZE = 38
 
 
 def create_model(
-    n_classes, model_spec, backbone, bottleneck, dropout_rate, input_shape
+    n_classes, model_spec, backbone, bottleneck, dropout_rate, input_shape, pre_train
 ):
 
     if backbone == "inception-resnet-v2":
@@ -166,20 +171,18 @@ def create_model(
     pre_model = add_bottleneck(
         pre_model, bottleneck_size=bottleneck, dropout_rate=dropout_rate
     )
-    pre_model = add_top(pre_model, n_classes=n_classes)
 
-    float32_layer = layers.Activation("linear", dtype="float32")
+    embeddings = pre_model.get_layer("embeddings").output
 
-    embeddings = tf.nn.l2_normalize(
-        pre_model.get_layer("embeddings/BatchNorm").output, axis=1
-    )
+    if pre_train:
+        pre_model = add_top(pre_model, n_classes=n_classes)
 
-    logits_premodel = float32_layer(pre_model.get_layer("logits").output)
+        # Wrapping the embedding validation
+        logits_premodel = pre_model.get_layer("logits").output
 
-    # Wrapping the embedding validation
-    pre_model = EmbeddingValidation(
-        pre_model.input, outputs=[logits_premodel, embeddings], name=pre_model.name
-    )
+        pre_model = EmbeddingValidation(
+            pre_model.input, outputs=[logits_premodel, embeddings], name=pre_model.name
+        )
 
     ################################
     ## Creating the specific models
@@ -221,17 +224,34 @@ def create_model(
 
 
 def build_and_compile_models(
-    n_classes, optimizer, model_spec, backbone, bottleneck, dropout_rate, input_shape
+    n_classes,
+    optimizer,
+    model_spec,
+    backbone,
+    bottleneck,
+    dropout_rate,
+    input_shape,
+    pre_train,
 ):
     pre_model, arc_model = create_model(
-        n_classes, model_spec, backbone, bottleneck, dropout_rate, input_shape
+        n_classes,
+        model_spec,
+        backbone,
+        bottleneck,
+        dropout_rate,
+        input_shape,
+        pre_train,
     )
 
     cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
         from_logits=True, name="cross_entropy"
     )
 
-    pre_model.compile(optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"])
+    # Compile the Cross-entropy model if the case
+    if pre_train:
+        pre_model.compile(
+            optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"],
+        )
 
     arc_model.compile(optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"])
 
@@ -252,6 +272,7 @@ def train_and_evaluate(
     face_size,
     validation_path,
     lerning_rate_schedule,
+    pre_train=False,
 ):
 
     # number of training steps to do before validating a model. This also defines an epoch
@@ -300,6 +321,7 @@ def train_and_evaluate(
         bottleneck=bottleneck,
         dropout_rate=dropout_rate,
         input_shape=OUTPUT_SHAPE + (3,),
+        pre_train=pre_train,
     )
 
     def scheduler(epoch, lr):
@@ -312,8 +334,10 @@ def train_and_evaluate(
 
         if epoch in range(200):
             return 1 * lr
+        elif epoch < 1000:
+            return lr * np.exp(-0.005)
         else:
-            return lr * tf.math.exp(-0.01)
+            return 0.0001
 
     if lerning_rate_schedule == "cosine-decay-restarts":
         decay_steps = 50
@@ -339,16 +363,19 @@ def train_and_evaluate(
     }
 
     callbacks = add_backup_callback(callbacks, backup_dir=f"{checkpoint_path}/backup")
-    # STEPS_PER_EPOCH
-    pre_model.fit(
-        train_ds,
-        epochs=2,
-        validation_data=val_ds,
-        steps_per_epoch=STEPS_PER_EPOCH,
-        validation_steps=VALIDATION_SAMPLES // VALIDATION_BATCH_SIZE,
-        callbacks=callbacks,
-        verbose=2,
-    )
+
+    # Train the Cross-entropy model if the case
+    if pre_train:
+        # STEPS_PER_EPOCH
+        pre_model.fit(
+            train_ds,
+            epochs=2,
+            validation_data=val_ds,
+            steps_per_epoch=STEPS_PER_EPOCH,
+            validation_steps=VALIDATION_SAMPLES // VALIDATION_BATCH_SIZE,
+            callbacks=callbacks,
+            verbose=2,
+        )
 
     # STEPS_PER_EPOCH
     # epochs=epochs * KERAS_EPOCH_MULTIPLIER,
@@ -404,6 +431,9 @@ if __name__ == "__main__":
         dropout_rate=float(config["dropout-rate"]),
         face_size=int(config["face-size"]),
         validation_path=config["validation-tf-record-path"],
-        lerning_rate_schedule=config["lerning-rate-schedule"],
+        lerning_rate_schedule=config["lerning-rate-schedule"]
+        if "lerning-rate-schedule" in config
+        else None,
+        pre_train=args["--pre-train"],
     )
 
-- 
GitLab