Commit 8a322bcc authored by Tiago de Freitas Pereira

Improvements in the arcface trainer

parent 0e02e70f
1 merge request: !102 New baselines
@@ -52,11 +52,12 @@ validation-tf-record-path: "/path/lfw_pairs.tfrecord"
 
 Usage:
-    arcface.py <config-yaml> <checkpoint_path>
+    arcface.py <config-yaml> <checkpoint_path> [--pre-train]
     arcface.py -h | --help
 
 Options:
     -h --help     Show this screen.
+    --pre-train   If set, pre-trains the CNN with the cross-entropy softmax for 2 epochs
     arcface.py arcface -h | help
 """
@@ -67,6 +68,7 @@ from functools import partial
 import pkg_resources
 import tensorflow as tf
 from bob.learn.tensorflow.models.inception_resnet_v2 import InceptionResNetV2
+from bob.learn.tensorflow.models import resnet50v1
 from bob.learn.tensorflow.metrics import predict_using_tensors
 from tensorflow.keras import layers
 from bob.learn.tensorflow.callbacks import add_backup_callback
@@ -99,6 +101,9 @@ BACKBONES["inception-resnet-v2"] = InceptionResNetV2
 BACKBONES["efficientnet-B0"] = tf.keras.applications.EfficientNetB0
 BACKBONES["resnet50"] = tf.keras.applications.ResNet50
 BACKBONES["mobilenet-v2"] = tf.keras.applications.MobileNetV2
+# from bob.learn.tensorflow.models.lenet5 import LeNet5_simplified
+BACKBONES["resnet50v1"] = resnet50v1
+
 ##############################
 # SOLVER SPECIFICATIONS
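BACKBONES is a plain name-to-constructor registry, and this hunk registers the new resnet50v1 model alongside the stock Keras applications. A sketch of how such a registry is typically consumed, with illustrative arguments (not this file's exact call):

    import tensorflow as tf

    BACKBONES = {"resnet50": tf.keras.applications.ResNet50}

    # Resolve the constructor from the config string, then build a headless backbone
    constructor = BACKBONES["resnet50"]
    backbone = constructor(include_top=False, weights=None, input_shape=(112, 112, 3))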
...@@ -150,7 +155,7 @@ VALIDATION_BATCH_SIZE = 38 ...@@ -150,7 +155,7 @@ VALIDATION_BATCH_SIZE = 38
def create_model( def create_model(
n_classes, model_spec, backbone, bottleneck, dropout_rate, input_shape n_classes, model_spec, backbone, bottleneck, dropout_rate, input_shape, pre_train
): ):
if backbone == "inception-resnet-v2": if backbone == "inception-resnet-v2":
@@ -166,20 +171,18 @@ def create_model(
     pre_model = add_bottleneck(
         pre_model, bottleneck_size=bottleneck, dropout_rate=dropout_rate
     )
-    pre_model = add_top(pre_model, n_classes=n_classes)
 
-    float32_layer = layers.Activation("linear", dtype="float32")
-    embeddings = tf.nn.l2_normalize(
-        pre_model.get_layer("embeddings/BatchNorm").output, axis=1
-    )
-    logits_premodel = float32_layer(pre_model.get_layer("logits").output)
+    embeddings = pre_model.get_layer("embeddings").output
+    if pre_train:
+        pre_model = add_top(pre_model, n_classes=n_classes)
 
-    # Wrapping the embedding validation
-    pre_model = EmbeddingValidation(
-        pre_model.input, outputs=[logits_premodel, embeddings], name=pre_model.name
-    )
+        logits_premodel = pre_model.get_layer("logits").output
+
+        # Wrapping the embedding validation
+        pre_model = EmbeddingValidation(
+            pre_model.input, outputs=[logits_premodel, embeddings], name=pre_model.name
+        )
 
     ################################
     ## Creating the specific models
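The restructured block always reads the embedding tensor from the layer named "embeddings", and builds the softmax top plus the EmbeddingValidation wrapper over [logits, embeddings] only when pre-training. A toy sketch of that two-output wrapping pattern, using a stock tf.keras.Model and hypothetical layers in place of EmbeddingValidation:

    import tensorflow as tf

    inputs = tf.keras.Input(shape=(112, 112, 3))
    x = tf.keras.layers.Conv2D(8, 3, activation="relu")(inputs)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    embeddings = tf.keras.layers.Dense(128, name="embeddings")(x)
    logits = tf.keras.layers.Dense(10, name="logits")(embeddings)

    # One model exposing both heads, analogous to the wrapping in the hunk above
    model = tf.keras.Model(inputs, outputs=[logits, embeddings])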
...@@ -221,17 +224,34 @@ def create_model( ...@@ -221,17 +224,34 @@ def create_model(
def build_and_compile_models( def build_and_compile_models(
n_classes, optimizer, model_spec, backbone, bottleneck, dropout_rate, input_shape n_classes,
optimizer,
model_spec,
backbone,
bottleneck,
dropout_rate,
input_shape,
pre_train,
): ):
pre_model, arc_model = create_model( pre_model, arc_model = create_model(
n_classes, model_spec, backbone, bottleneck, dropout_rate, input_shape n_classes,
model_spec,
backbone,
bottleneck,
dropout_rate,
input_shape,
pre_train,
) )
cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy( cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, name="cross_entropy" from_logits=True, name="cross_entropy"
) )
pre_model.compile(optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"]) # Compile the Cross-entropy model if the case
if pre_train:
pre_model.compile(
optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"],
)
arc_model.compile(optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"]) arc_model.compile(optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"])
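Both compiles use SparseCategoricalCrossentropy with from_logits=True, so the classification heads are expected to emit raw scores and the labels to be integer class ids rather than one-hot vectors:

    import tensorflow as tf

    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    logits = tf.constant([[2.0, 0.5, -1.0]])  # raw scores; softmax is applied inside the loss
    labels = tf.constant([0])                 # integer class ids, not one-hot
    print(float(loss(labels, logits)))        # negative log-probability of class 0, about 0.24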
@@ -252,6 +272,7 @@ def train_and_evaluate(
     face_size,
     validation_path,
     lerning_rate_schedule,
+    pre_train=False,
 ):
 
     # number of training steps to do before validating a model. This also defines an epoch
@@ -300,6 +321,7 @@ def train_and_evaluate(
         bottleneck=bottleneck,
         dropout_rate=dropout_rate,
         input_shape=OUTPUT_SHAPE + (3,),
+        pre_train=pre_train,
     )
 
     def scheduler(epoch, lr):
@@ -312,8 +334,10 @@ def train_and_evaluate(
         if epoch in range(200):
             return 1 * lr
+        elif epoch < 1000:
+            return lr * np.exp(-0.005)
         else:
-            return lr * tf.math.exp(-0.01)
+            return 0.0001
 
     if lerning_rate_schedule == "cosine-decay-restarts":
         decay_steps = 50
@@ -339,16 +363,19 @@ def train_and_evaluate(
     }
     callbacks = add_backup_callback(callbacks, backup_dir=f"{checkpoint_path}/backup")
 
-    # STEPS_PER_EPOCH
-    pre_model.fit(
-        train_ds,
-        epochs=2,
-        validation_data=val_ds,
-        steps_per_epoch=STEPS_PER_EPOCH,
-        validation_steps=VALIDATION_SAMPLES // VALIDATION_BATCH_SIZE,
-        callbacks=callbacks,
-        verbose=2,
-    )
+    # Train the cross-entropy model first if pre-training was requested
+    if pre_train:
+        # STEPS_PER_EPOCH
+        pre_model.fit(
+            train_ds,
+            epochs=2,
+            validation_data=val_ds,
+            steps_per_epoch=STEPS_PER_EPOCH,
+            validation_steps=VALIDATION_SAMPLES // VALIDATION_BATCH_SIZE,
+            callbacks=callbacks,
+            verbose=2,
+        )
 
     # STEPS_PER_EPOCH
     # epochs=epochs * KERAS_EPOCH_MULTIPLIER,
@@ -404,6 +431,9 @@ if __name__ == "__main__":
         dropout_rate=float(config["dropout-rate"]),
         face_size=int(config["face-size"]),
         validation_path=config["validation-tf-record-path"],
-        lerning_rate_schedule=config["lerning-rate-schedule"],
+        lerning_rate_schedule=config["lerning-rate-schedule"]
+        if "lerning-rate-schedule" in config
+        else None,
+        pre_train=args["--pre-train"],
     )
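The guarded lookup added for lerning-rate-schedule is equivalent to a dict.get, which already returns None for a missing key; both spellings behave the same:

    config = {}

    # The diff's spelling
    schedule = config["lerning-rate-schedule"] if "lerning-rate-schedule" in config else None

    # Equivalent and shorter
    schedule = config.get("lerning-rate-schedule")
    assert schedule is None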