Improvements in the ArcFace trainer

parent 0e02e70f
@@ -52,11 +52,12 @@ validation-tf-record-path: "/path/lfw_pairs.tfrecord"

 Usage:
-    arcface.py <config-yaml> <checkpoint_path>
+    arcface.py <config-yaml> <checkpoint_path> [--pre-train]
     arcface.py -h | --help

 Options:
   -h --help     Show this screen.
+  --pre-train   If set, pre-trains the CNN with the cross-entropy softmax for 2 epochs
   arcface.py arcface -h | help
 """
@@ -67,6 +68,7 @@ from functools import partial
 import pkg_resources
 import tensorflow as tf
 from bob.learn.tensorflow.models.inception_resnet_v2 import InceptionResNetV2
+from bob.learn.tensorflow.models import resnet50v1
 from bob.learn.tensorflow.metrics import predict_using_tensors
 from tensorflow.keras import layers
 from bob.learn.tensorflow.callbacks import add_backup_callback
@@ -99,6 +101,9 @@ BACKBONES["inception-resnet-v2"] = InceptionResNetV2
 BACKBONES["efficientnet-B0"] = tf.keras.applications.EfficientNetB0
 BACKBONES["resnet50"] = tf.keras.applications.ResNet50
 BACKBONES["mobilenet-v2"] = tf.keras.applications.MobileNetV2
+
+# from bob.learn.tensorflow.models.lenet5 import LeNet5_simplified
+BACKBONES["resnet50v1"] = resnet50v1

 ##############################
 # SOLVER SPECIFICATIONS
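The `BACKBONES` dict maps the backbone name given in the YAML config to a Keras model constructor, so registering `resnet50v1` here makes it selectable by name. A minimal sketch of how such a registry is typically consumed; the `make_backbone` helper, the constructor kwargs, and the 112x112 input shape are assumptions for illustration, not code from this file:

import tensorflow as tf

BACKBONES = dict()
BACKBONES["resnet50"] = tf.keras.applications.ResNet50

def make_backbone(name, input_shape=(112, 112, 3)):
    # Look up the constructor registered under `name` and build it headless,
    # so the trainer can attach its own bottleneck and classification top.
    constructor = BACKBONES[name]
    return constructor(include_top=False, input_shape=input_shape, weights=None)

backbone = make_backbone("resnet50")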
@@ -150,7 +155,7 @@ VALIDATION_BATCH_SIZE = 38
 def create_model(
-    n_classes, model_spec, backbone, bottleneck, dropout_rate, input_shape
+    n_classes, model_spec, backbone, bottleneck, dropout_rate, input_shape, pre_train
 ):

     if backbone == "inception-resnet-v2":
@@ -166,20 +171,18 @@ def create_model(
     pre_model = add_bottleneck(
         pre_model, bottleneck_size=bottleneck, dropout_rate=dropout_rate
     )
-    pre_model = add_top(pre_model, n_classes=n_classes)
-
-    float32_layer = layers.Activation("linear", dtype="float32")
-    embeddings = tf.nn.l2_normalize(
-        pre_model.get_layer("embeddings/BatchNorm").output, axis=1
-    )
-    logits_premodel = float32_layer(pre_model.get_layer("logits").output)
-
-    # Wrapping the embedding validation
-    pre_model = EmbeddingValidation(
-        pre_model.input, outputs=[logits_premodel, embeddings], name=pre_model.name
-    )
+    embeddings = pre_model.get_layer("embeddings").output
+    if pre_train:
+        pre_model = add_top(pre_model, n_classes=n_classes)
+
+        # Wrapping the embedding validation
+        logits_premodel = pre_model.get_layer("logits").output
+        pre_model = EmbeddingValidation(
+            pre_model.input, outputs=[logits_premodel, embeddings], name=pre_model.name
+        )

 ################################
 ## Creating the specific models
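The reworked branch taps the embedding tensor from the `embeddings` layer unconditionally, and only when `--pre-train` is set does it attach the softmax top and rewrap the model so it exposes both the logits head (for cross-entropy) and the embeddings (for verification-style validation). `EmbeddingValidation` is this repo's Keras model subclass; the sketch below shows the same two-output pattern with a stock `tf.keras.Model` and a toy backbone (all layer sizes are made up for illustration):

import tensorflow as tf
from tensorflow.keras import layers

inputs = tf.keras.Input(shape=(112, 112, 3))
x = layers.GlobalAveragePooling2D()(layers.Conv2D(32, 3)(inputs))
embeddings = layers.Dense(512, name="embeddings")(x)    # the always-present tap
logits = layers.Dense(1000, name="logits")(embeddings)  # only built for pre-training

# Analogue of EmbeddingValidation(pre_model.input, outputs=[logits, embeddings], ...)
model = tf.keras.Model(inputs, outputs=[logits, embeddings])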
@@ -221,17 +224,34 @@ def create_model(
 def build_and_compile_models(
-    n_classes, optimizer, model_spec, backbone, bottleneck, dropout_rate, input_shape
+    n_classes,
+    optimizer,
+    model_spec,
+    backbone,
+    bottleneck,
+    dropout_rate,
+    input_shape,
+    pre_train,
 ):
     pre_model, arc_model = create_model(
-        n_classes, model_spec, backbone, bottleneck, dropout_rate, input_shape
+        n_classes,
+        model_spec,
+        backbone,
+        bottleneck,
+        dropout_rate,
+        input_shape,
+        pre_train,
     )

     cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
         from_logits=True, name="cross_entropy"
     )

-    pre_model.compile(optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"])
+    # Compile the cross-entropy model only when pre-training is requested
+    if pre_train:
+        pre_model.compile(
+            optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"],
+        )

     arc_model.compile(optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"])
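Both models are trained with `SparseCategoricalCrossentropy(from_logits=True)`: the heads emit raw logits and the loss applies the softmax internally, which is numerically more stable than softmaxing inside the model. A quick self-contained check of that behavior (the logit and label values are illustrative):

import tensorflow as tf

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
logits = tf.constant([[2.0, 0.5, -1.0]])  # raw, un-normalized scores for 3 classes
labels = tf.constant([0])                 # integer labels, no one-hot needed
print(loss(labels, logits).numpy())       # ~0.24, i.e. -log(softmax(logits)[0])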
@@ -252,6 +272,7 @@ def train_and_evaluate(
     face_size,
     validation_path,
     lerning_rate_schedule,
+    pre_train=False,
 ):

     # number of training steps to do before validating a model. This also defines an epoch
@@ -300,6 +321,7 @@ def train_and_evaluate(
         bottleneck=bottleneck,
         dropout_rate=dropout_rate,
         input_shape=OUTPUT_SHAPE + (3,),
+        pre_train=pre_train,
     )

     def scheduler(epoch, lr):
@@ -312,8 +334,10 @@ def train_and_evaluate(
         if epoch in range(200):
             return 1 * lr
+        elif epoch < 1000:
+            return lr * np.exp(-0.005)
         else:
-            return lr * tf.math.exp(-0.01)
+            return 0.0001

     if lerning_rate_schedule == "cosine-decay-restarts":
         decay_steps = 50
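The new schedule keeps the base rate flat for the first 200 epochs, decays it by exp(-0.005) (about 0.5%) per epoch until epoch 1000, then clamps it to 1e-4; note the added branch uses `np.exp`, which presumes numpy is imported as `np` elsewhere in the file. A worked example of the cumulative effect (the 0.1 base rate is an arbitrary choice for illustration):

import numpy as np

def scheduler(epoch, lr):
    if epoch in range(200):        # epochs 0..199: keep the current rate
        return 1 * lr
    elif epoch < 1000:             # epochs 200..999: multiply by exp(-0.005)
        return lr * np.exp(-0.005)
    else:                          # epoch >= 1000: fixed small rate
        return 0.0001

lr = 0.1
for epoch in range(1000):
    lr = scheduler(epoch, lr)
print(lr)  # 0.1 * exp(-0.005 * 800) ≈ 0.0018 after 800 decaying epochs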
@@ -339,16 +363,19 @@ def train_and_evaluate(
     }
     callbacks = add_backup_callback(callbacks, backup_dir=f"{checkpoint_path}/backup")

-    # STEPS_PER_EPOCH
-    pre_model.fit(
-        train_ds,
-        epochs=2,
-        validation_data=val_ds,
-        steps_per_epoch=STEPS_PER_EPOCH,
-        validation_steps=VALIDATION_SAMPLES // VALIDATION_BATCH_SIZE,
-        callbacks=callbacks,
-        verbose=2,
-    )
+    # Train the cross-entropy model first, if pre-training was requested
+    if pre_train:
+        # STEPS_PER_EPOCH
+        pre_model.fit(
+            train_ds,
+            epochs=2,
+            validation_data=val_ds,
+            steps_per_epoch=STEPS_PER_EPOCH,
+            validation_steps=VALIDATION_SAMPLES // VALIDATION_BATCH_SIZE,
+            callbacks=callbacks,
+            verbose=2,
+        )

     # STEPS_PER_EPOCH
     # epochs=epochs * KERAS_EPOCH_MULTIPLIER,
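The two-epoch pre-training fit (matching the `--pre-train` help text) reuses the same callbacks that `add_backup_callback` assembled, so pre-training progress is checkpointed into the same backup directory as the main ArcFace fit. `add_backup_callback` is bob.learn.tensorflow's helper; a rough stock-Keras equivalent of that wiring, with a placeholder path, might look like:

import tensorflow as tf

checkpoint_path = "/tmp/arcface-ckpt"  # hypothetical location
callbacks = [
    # Resumes an interrupted fit() from the last saved epoch
    tf.keras.callbacks.BackupAndRestore(backup_dir=f"{checkpoint_path}/backup"),
]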
@@ -404,6 +431,9 @@ if __name__ == "__main__":
         dropout_rate=float(config["dropout-rate"]),
         face_size=int(config["face-size"]),
         validation_path=config["validation-tf-record-path"],
-        lerning_rate_schedule=config["lerning-rate-schedule"],
+        lerning_rate_schedule=config["lerning-rate-schedule"]
+        if "lerning-rate-schedule" in config
+        else None,
+        pre_train=args["--pre-train"],
     )
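The inline conditional makes the `lerning-rate-schedule` YAML key optional, falling back to None, while `pre_train` is wired straight to the docopt boolean. The conditional is equivalent to a plain `dict.get`, shown here on a made-up config dict:

config = {"dropout-rate": "0.5"}  # hypothetical parsed YAML; schedule key absent

# Same effect as: config["lerning-rate-schedule"] if "lerning-rate-schedule" in config else None
lerning_rate_schedule = config.get("lerning-rate-schedule", None)
print(lerning_rate_schedule)  # None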