Improvements in the arcface trainer

parent 0e02e70f
......@@ -52,11 +52,12 @@ validation-tf-record-path: "/path/lfw_pairs.tfrecord"
Usage:
arcface.py <config-yaml> <checkpoint_path>
arcface.py <config-yaml> <checkpoint_path> [--pre-train]
arcface.py -h | --help
Options:
-h --help Show this screen.
--pre-train If set, pre-trains the CNN with the softmax cross-entropy loss for 2 epochs
arcface.py arcface -h | help
"""
......@@ -67,6 +68,7 @@ from functools import partial
import pkg_resources
import tensorflow as tf
from bob.learn.tensorflow.models.inception_resnet_v2 import InceptionResNetV2
from bob.learn.tensorflow.models import resnet50v1
from bob.learn.tensorflow.metrics import predict_using_tensors
from tensorflow.keras import layers
from bob.learn.tensorflow.callbacks import add_backup_callback
......@@ -99,6 +101,9 @@ BACKBONES["inception-resnet-v2"] = InceptionResNetV2
BACKBONES["efficientnet-B0"] = tf.keras.applications.EfficientNetB0
BACKBONES["resnet50"] = tf.keras.applications.ResNet50
BACKBONES["mobilenet-v2"] = tf.keras.applications.MobileNetV2
# from bob.learn.tensorflow.models.lenet5 import LeNet5_simplified
BACKBONES["resnet50v1"] = resnet50v1
##############################
# SOLVER SPECIFICATIONS
......@@ -150,7 +155,7 @@ VALIDATION_BATCH_SIZE = 38
def create_model(
n_classes, model_spec, backbone, bottleneck, dropout_rate, input_shape
n_classes, model_spec, backbone, bottleneck, dropout_rate, input_shape, pre_train
):
if backbone == "inception-resnet-v2":
......@@ -166,17 +171,15 @@ def create_model(
pre_model = add_bottleneck(
pre_model, bottleneck_size=bottleneck, dropout_rate=dropout_rate
)
pre_model = add_top(pre_model, n_classes=n_classes)
float32_layer = layers.Activation("linear", dtype="float32")
embeddings = pre_model.get_layer("embeddings").output
embeddings = tf.nn.l2_normalize(
pre_model.get_layer("embeddings/BatchNorm").output, axis=1
)
logits_premodel = float32_layer(pre_model.get_layer("logits").output)
if pre_train:
pre_model = add_top(pre_model, n_classes=n_classes)
# Wrapping the embedding validation
logits_premodel = pre_model.get_layer("logits").output
pre_model = EmbeddingValidation(
pre_model.input, outputs=[logits_premodel, embeddings], name=pre_model.name
)
......@@ -221,17 +224,34 @@ def create_model(
def build_and_compile_models(
n_classes, optimizer, model_spec, backbone, bottleneck, dropout_rate, input_shape
n_classes,
optimizer,
model_spec,
backbone,
bottleneck,
dropout_rate,
input_shape,
pre_train,
):
pre_model, arc_model = create_model(
n_classes, model_spec, backbone, bottleneck, dropout_rate, input_shape
n_classes,
model_spec,
backbone,
bottleneck,
dropout_rate,
input_shape,
pre_train,
)
cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, name="cross_entropy"
)
pre_model.compile(optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"])
# Compile the cross-entropy model only when pre-training is requested
if pre_train:
pre_model.compile(
optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"],
)
arc_model.compile(optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"])
......@@ -252,6 +272,7 @@ def train_and_evaluate(
face_size,
validation_path,
lerning_rate_schedule,
pre_train=False,
):
# number of training steps to do before validating a model. This also defines an epoch
......@@ -300,6 +321,7 @@ def train_and_evaluate(
bottleneck=bottleneck,
dropout_rate=dropout_rate,
input_shape=OUTPUT_SHAPE + (3,),
pre_train=pre_train,
)
def scheduler(epoch, lr):
......@@ -312,8 +334,10 @@ def train_and_evaluate(
if epoch in range(200):
return 1 * lr
elif epoch < 1000:
return lr * np.exp(-0.005)
else:
return lr * tf.math.exp(-0.01)
return 0.0001
if lerning_rate_schedule == "cosine-decay-restarts":
decay_steps = 50
......@@ -339,6 +363,9 @@ def train_and_evaluate(
}
callbacks = add_backup_callback(callbacks, backup_dir=f"{checkpoint_path}/backup")
# Train the cross-entropy model only when pre-training is requested
if pre_train:
# STEPS_PER_EPOCH
pre_model.fit(
train_ds,
......@@ -404,6 +431,9 @@ if __name__ == "__main__":
dropout_rate=float(config["dropout-rate"]),
face_size=int(config["face-size"]),
validation_path=config["validation-tf-record-path"],
lerning_rate_schedule=config["lerning-rate-schedule"],
lerning_rate_schedule=config["lerning-rate-schedule"]
if "lerning-rate-schedule" in config
else None,
pre_train=args["--pre-train"],
)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment