diff --git a/bob/bio/face/config/baseline/helpers.py b/bob/bio/face/config/baseline/helpers.py index 49e7c3ecb29f3bde3e0862d81fac7cc94424a14f..ffeb149681712008274aac8b054202e3a80ff9f2 100644 --- a/bob/bio/face/config/baseline/helpers.py +++ b/bob/bio/face/config/baseline/helpers.py @@ -235,7 +235,11 @@ def embedding_transformer_112x112( cropped_image_size = (112, 112) if annotation_type == "eyes-center": # Hard coding eye positions for backward consistency - cropped_positions = {"leye": (49, 72), "reye": (49, 38)} + cropped_positions = { + "leye": (55, 81), + "reye": (55, 42), + } + else: # Will use default cropped_positions = embedding_transformer_default_cropping( diff --git a/bob/bio/face/config/baseline/mobilenetv2_msceleb_arcface_2021.py b/bob/bio/face/config/baseline/mobilenetv2_msceleb_arcface_2021.py new file mode 100644 index 0000000000000000000000000000000000000000..e68e7ef3edf5d8c080efc9aac7bf06a39d85982b --- /dev/null +++ b/bob/bio/face/config/baseline/mobilenetv2_msceleb_arcface_2021.py @@ -0,0 +1,34 @@ +from bob.bio.face.embeddings.mobilenet_v2 import MobileNetv2_MsCeleb_ArcFace_2021 +from bob.bio.face.config.baseline.helpers import embedding_transformer_112x112 +from bob.bio.base.pipelines.vanilla_biometrics import ( + Distance, + VanillaBiometricsPipeline, +) + +memory_demanding = False +if "database" in locals(): + annotation_type = database.annotation_type + fixed_positions = database.fixed_positions + + memory_demanding = ( + database.memory_demanding if hasattr(database, "memory_demanding") else False + ) +else: + annotation_type = None + fixed_positions = None + + +def load(annotation_type, fixed_positions=None): + transformer = embedding_transformer_112x112( + MobileNetv2_MsCeleb_ArcFace_2021(memory_demanding=memory_demanding), + annotation_type, + fixed_positions, + ) + + algorithm = Distance() + + return VanillaBiometricsPipeline(transformer, algorithm) + + +pipeline = load(annotation_type, fixed_positions) +transformer = pipeline.transformer diff --git a/bob/bio/face/config/baseline/resnet50_msceleb_arcface_2021.py b/bob/bio/face/config/baseline/resnet50_msceleb_arcface_2021.py new file mode 100644 index 0000000000000000000000000000000000000000..dfeb4b74af2958879e83c74b4428c2cb7601d308 --- /dev/null +++ b/bob/bio/face/config/baseline/resnet50_msceleb_arcface_2021.py @@ -0,0 +1,34 @@ +from bob.bio.face.embeddings.resnet50 import Resnet50_MsCeleb_ArcFace_2021 +from bob.bio.face.config.baseline.helpers import embedding_transformer_112x112 +from bob.bio.base.pipelines.vanilla_biometrics import ( + Distance, + VanillaBiometricsPipeline, +) + +memory_demanding = False +if "database" in locals(): + annotation_type = database.annotation_type + fixed_positions = database.fixed_positions + + memory_demanding = ( + database.memory_demanding if hasattr(database, "memory_demanding") else False + ) +else: + annotation_type = None + fixed_positions = None + + +def load(annotation_type, fixed_positions=None): + transformer = embedding_transformer_112x112( + Resnet50_MsCeleb_ArcFace_2021(memory_demanding=memory_demanding), + annotation_type, + fixed_positions, + ) + + algorithm = Distance() + + return VanillaBiometricsPipeline(transformer, algorithm) + + +pipeline = load(annotation_type, fixed_positions) +transformer = pipeline.transformer diff --git a/bob/bio/face/embeddings/mobilenet_v2.py b/bob/bio/face/embeddings/mobilenet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..4966583ff6388b44706eec199fc3e6dcb8768359 --- /dev/null +++ b/bob/bio/face/embeddings/mobilenet_v2.py @@ -0,0 +1,79 @@ +from bob.bio.face.embeddings import download_model + + +from .tf2_inception_resnet import TransformTensorflow +import pkg_resources +import os +from bob.extension import rc +import tensorflow as tf + + +class MobileNetv2_MsCeleb_ArcFace_2021(TransformTensorflow): + """ + MobileNet Backbone trained with the MSCeleb 1M database. + + The bottleneck layer (a.k.a embedding) has 512d. + + The configuration file used to trained is: + + ```yaml + batch-size: 128 + face-size: 112 + face-output_size: 112 + n-classes: 85742 + + + ## Backbone + backbone: 'mobilenet-v2' + head: 'arcface' + s: 10 + bottleneck: 512 + m: 0.5 + + # Training parameters + solver: "sgd" + lr: 0.01 + dropout-rate: 0.5 + epochs: 500 + + + train-tf-record-path: "<PATH>" + validation-tf-record-path: "<PATH>" + + ``` + + + """ + + def __init__(self, memory_demanding=False): + internal_path = pkg_resources.resource_filename( + __name__, os.path.join("data", "mobilenet-v2-msceleb-arcface-2021"), + ) + + checkpoint_path = ( + internal_path + if rc["bob.bio.face.models.mobilenet-v2-msceleb-arcface-2021"] is None + else rc["bob.bio.face.models.mobilenet-v2-msceleb-arcface-2021"] + ) + + urls = [ + "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/tf2/mobilenet-v2-msceleb-arcface-2021.tar.gz", + "http://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/tf2/mobilenet-v2-msceleb-arcface-2021.tar.gz", + ] + + download_model(checkpoint_path, urls, "mobilenet-v2-msceleb-arcface-2021.tar.gz") + + super(MobileNetv2_MsCeleb_ArcFace_2021, self).__init__( + checkpoint_path, + preprocessor=lambda X: X / 255.0, + memory_demanding=memory_demanding, + ) + + def inference(self, X): + if self.preprocessor is not None: + X = self.preprocessor(tf.cast(X, "float32")) + + prelogits = self.model.predict_on_batch(X)[0] + embeddings = tf.math.l2_normalize(prelogits, axis=-1) + return embeddings + diff --git a/bob/bio/face/embeddings/resnet50.py b/bob/bio/face/embeddings/resnet50.py new file mode 100644 index 0000000000000000000000000000000000000000..8a75b2ed602e12c0e2dba708ab672b37c0c4bb3b --- /dev/null +++ b/bob/bio/face/embeddings/resnet50.py @@ -0,0 +1,79 @@ +from bob.bio.face.embeddings import download_model + + +from .tf2_inception_resnet import TransformTensorflow +import pkg_resources +import os +from bob.extension import rc +import tensorflow as tf + + +class Resnet50_MsCeleb_ArcFace_2021(TransformTensorflow): + """ + Resnet50 Backbone trained with the MSCeleb 1M database. + + The bottleneck layer (a.k.a embedding) has 512d. + + The configuration file used to trained is: + + ```yaml + batch-size: 128 + face-size: 112 + face-output_size: 112 + n-classes: 85742 + + + ## Backbone + backbone: 'resnet50' + head: 'arcface' + s: 10 + bottleneck: 512 + m: 0.5 + + # Training parameters + solver: "sgd" + lr: 0.01 + dropout-rate: 0.5 + epochs: 500 + + + train-tf-record-path: "<PATH>" + validation-tf-record-path: "<PATH>" + + ``` + + + """ + + def __init__(self, memory_demanding=False): + internal_path = pkg_resources.resource_filename( + __name__, os.path.join("data", "resnet50_msceleb_arcface_2021"), + ) + + checkpoint_path = ( + internal_path + if rc["bob.bio.face.models.resnet50_msceleb_arcface_2021"] is None + else rc["bob.bio.face.models.resnet50_msceleb_arcface_2021"] + ) + + urls = [ + "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/tf2/resnet50_msceleb_arcface_2021.tar.gz", + "http://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/tf2/resnet50_msceleb_arcface_2021.tar.gz", + ] + + download_model(checkpoint_path, urls, "resnet50_msceleb_arcface_2021.tar.gz") + + super(Resnet50_MsCeleb_ArcFace_2021, self).__init__( + checkpoint_path, + preprocessor=lambda X: X / 255.0, + memory_demanding=memory_demanding, + ) + + def inference(self, X): + if self.preprocessor is not None: + X = self.preprocessor(tf.cast(X, "float32")) + + prelogits = self.model.predict_on_batch(X)[0] + embeddings = tf.math.l2_normalize(prelogits, axis=-1) + return embeddings + diff --git a/bob/bio/face/embeddings/tf2_inception_resnet.py b/bob/bio/face/embeddings/tf2_inception_resnet.py index 34b1c904df445c6eb0ecf4efdd9722664f41e445..b462f3a885c67153b15a4e5f930637e9ebd14cee 100644 --- a/bob/bio/face/embeddings/tf2_inception_resnet.py +++ b/bob/bio/face/embeddings/tf2_inception_resnet.py @@ -20,9 +20,9 @@ def sanderberg_rescaling(): return preprocessor -class InceptionResnet(TransformerMixin, BaseEstimator): +class TransformTensorflow(TransformerMixin, BaseEstimator): """ - Base Transformer for InceptionResnet architectures. + Base Transformer for Tensorflow architectures. Szegedy, Christian, et al. "Inception-v4, inception-resnet and the impact of residual connections on learning." arXiv preprint arXiv:1602.07261 (2016). @@ -52,14 +52,6 @@ class InceptionResnet(TransformerMixin, BaseEstimator): def load_model(self): self.model = tf.keras.models.load_model(self.checkpoint_path) - def inference(self, X): - if self.preprocessor is not None: - X = self.preprocessor(tf.cast(X, "float32")) - - prelogits = self.model.predict_on_batch(X) - embeddings = tf.math.l2_normalize(prelogits, axis=-1) - return embeddings - def transform(self, X): def _transform(X): X = tf.convert_to_tensor(X) @@ -88,6 +80,14 @@ class InceptionResnet(TransformerMixin, BaseEstimator): d["model"] = None return d + def inference(self, X): + if self.preprocessor is not None: + X = self.preprocessor(tf.cast(X, "float32")) + + prelogits = self.model.predict_on_batch(X) + embeddings = tf.math.l2_normalize(prelogits, axis=-1) + return embeddings + def _more_tags(self): return {"stateless": True, "requires_fit": False} @@ -95,7 +95,7 @@ class InceptionResnet(TransformerMixin, BaseEstimator): self.model = None -class InceptionResnetv2_MsCeleb_CenterLoss_2018(InceptionResnet): +class InceptionResnetv2_MsCeleb_CenterLoss_2018(TransformTensorflow): """ InceptionResnet v2 model trained in 2018 using the MSCeleb dataset in the context of the work: @@ -131,7 +131,7 @@ class InceptionResnetv2_MsCeleb_CenterLoss_2018(InceptionResnet): ) -class InceptionResnetv2_Casia_CenterLoss_2018(InceptionResnet): +class InceptionResnetv2_Casia_CenterLoss_2018(TransformTensorflow): """ InceptionResnet v2 model trained in 2018 using the CasiaWebFace dataset in the context of the work: @@ -166,7 +166,7 @@ class InceptionResnetv2_Casia_CenterLoss_2018(InceptionResnet): ) -class InceptionResnetv1_Casia_CenterLoss_2018(InceptionResnet): +class InceptionResnetv1_Casia_CenterLoss_2018(TransformTensorflow): """ InceptionResnet v1 model trained in 2018 using the CasiaWebFace dataset in the context of the work: @@ -201,7 +201,7 @@ class InceptionResnetv1_Casia_CenterLoss_2018(InceptionResnet): ) -class InceptionResnetv1_MsCeleb_CenterLoss_2018(InceptionResnet): +class InceptionResnetv1_MsCeleb_CenterLoss_2018(TransformTensorflow): """ InceptionResnet v1 model trained in 2018 using the MsCeleb dataset in the context of the work: @@ -237,7 +237,7 @@ class InceptionResnetv1_MsCeleb_CenterLoss_2018(InceptionResnet): ) -class FaceNetSanderberg_20170512_110547(InceptionResnet): +class FaceNetSanderberg_20170512_110547(TransformTensorflow): """ Wrapper for the free FaceNet from David Sanderberg model 20170512_110547: https://github.com/davidsandberg/facenet diff --git a/bob/bio/face/tensorflow/preprocessing.py b/bob/bio/face/tensorflow/preprocessing.py index bf87921c171140bb2f364bae7715da744fc4c8b5..4b92f27f44c427dcd56328356acce41554d7b86c 100644 --- a/bob/bio/face/tensorflow/preprocessing.py +++ b/bob/bio/face/tensorflow/preprocessing.py @@ -42,9 +42,10 @@ def get_preprocessor(output_shape): layers.experimental.preprocessing.RandomFlip("horizontal"), # FIXED_STANDARDIZATION from https://github.com/davidsandberg/facenet # [-0.99609375, 0.99609375] - layers.experimental.preprocessing.Rescaling( - scale=1 / 128, offset=-127.5 / 128 - ), + # layers.experimental.preprocessing.Rescaling( + # scale=1 / 128, offset=-127.5 / 128 + # ), + layers.experimental.preprocessing.Rescaling(scale=1 / 255, offset=0), ] ) return preprocessor @@ -100,7 +101,9 @@ def prepare_dataset( ignore_order = tf.data.Options() ignore_order.experimental_deterministic = False ds = ds.with_options(ignore_order) - ds = ds.map(partial(decode_tfrecords, data_shape=data_shape)).prefetch(buffer_size=autotune) + ds = ds.map(partial(decode_tfrecords, data_shape=data_shape)).prefetch( + buffer_size=autotune + ) if shuffle: ds = ds.shuffle(shuffle_buffer).repeat(epochs) preprocessor = get_preprocessor(output_shape) diff --git a/cnn_training/arcface.py b/cnn_training/arcface.py index fb59350947ca4cfc8a0f8bc7eb171d779f7b536b..712cf6dca83019dae9adea274a5622b491a295b9 100644 --- a/cnn_training/arcface.py +++ b/cnn_training/arcface.py @@ -52,11 +52,12 @@ validation-tf-record-path: "/path/lfw_pairs.tfrecord" Usage: - arcface.py <config-yaml> <checkpoint_path> + arcface.py <config-yaml> <checkpoint_path> [--pre-train] arcface.py -h | --help Options: -h --help Show this screen. + --pre-train If set pretrains the CNN with the crossentropy softmax for 2 epochs arcface.py arcface -h | help """ @@ -67,6 +68,7 @@ from functools import partial import pkg_resources import tensorflow as tf from bob.learn.tensorflow.models.inception_resnet_v2 import InceptionResNetV2 +from bob.learn.tensorflow.models import resnet50v1 from bob.learn.tensorflow.metrics import predict_using_tensors from tensorflow.keras import layers from bob.learn.tensorflow.callbacks import add_backup_callback @@ -99,6 +101,9 @@ BACKBONES["inception-resnet-v2"] = InceptionResNetV2 BACKBONES["efficientnet-B0"] = tf.keras.applications.EfficientNetB0 BACKBONES["resnet50"] = tf.keras.applications.ResNet50 BACKBONES["mobilenet-v2"] = tf.keras.applications.MobileNetV2 +# from bob.learn.tensorflow.models.lenet5 import LeNet5_simplified + +BACKBONES["resnet50v1"] = resnet50v1 ############################## # SOLVER SPECIFICATIONS @@ -150,7 +155,7 @@ VALIDATION_BATCH_SIZE = 38 def create_model( - n_classes, model_spec, backbone, bottleneck, dropout_rate, input_shape + n_classes, model_spec, backbone, bottleneck, dropout_rate, input_shape, pre_train ): if backbone == "inception-resnet-v2": @@ -166,27 +171,34 @@ def create_model( pre_model = add_bottleneck( pre_model, bottleneck_size=bottleneck, dropout_rate=dropout_rate ) - pre_model = add_top(pre_model, n_classes=n_classes) - float32_layer = layers.Activation("linear", dtype="float32") + embeddings = pre_model.get_layer("embeddings").output + labels = tf.keras.layers.Input([], name="label") - embeddings = tf.nn.l2_normalize( - pre_model.get_layer("embeddings/BatchNorm").output, axis=1 - ) + if pre_train: + # pre_model = add_top(pre_model, n_classes=n_classes) + logits_premodel = ArcFaceLayer( + n_classes, + s=model_spec["arcface"]["s"], + m=model_spec["arcface"]["m"], + arc=False, + )(embeddings, None) - logits_premodel = float32_layer(pre_model.get_layer("logits").output) + # Wrapping the embedding validation + # logits_premodel = pre_model.get_layer("logits").output - # Wrapping the embedding validation - pre_model = EmbeddingValidation( - pre_model.input, outputs=[logits_premodel, embeddings], name=pre_model.name - ) + pre_model = EmbeddingValidation( + pre_model.input, outputs=[logits_premodel, embeddings], name=pre_model.name + ) ################################ ## Creating the specific models if "arcface" in model_spec: - labels = tf.keras.layers.Input([], name="label") logits_arcface = ArcFaceLayer( - n_classes, s=model_spec["arcface"]["s"], m=model_spec["arcface"]["m"] + n_classes, + s=model_spec["arcface"]["s"], + m=model_spec["arcface"]["m"], + arc=True, )(embeddings, labels) arc_model = ArcFaceModel( inputs=(pre_model.input, labels), outputs=[logits_arcface, embeddings] @@ -221,19 +233,38 @@ def create_model( def build_and_compile_models( - n_classes, optimizer, model_spec, backbone, bottleneck, dropout_rate, input_shape + n_classes, + optimizer, + model_spec, + backbone, + bottleneck, + dropout_rate, + input_shape, + pre_train, ): pre_model, arc_model = create_model( - n_classes, model_spec, backbone, bottleneck, dropout_rate, input_shape + n_classes, + model_spec, + backbone, + bottleneck, + dropout_rate, + input_shape, + pre_train, ) cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy( from_logits=True, name="cross_entropy" ) - pre_model.compile(optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"]) + # Compile the Cross-entropy model if the case + if pre_train: + pre_model.compile( + optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"], + ) - arc_model.compile(optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"]) + arc_model.compile( + optimizer=optimizer, loss=cross_entropy, run_eagerly=True, metrics=["accuracy"] + ) return pre_model, arc_model @@ -252,6 +283,7 @@ def train_and_evaluate( face_size, validation_path, lerning_rate_schedule, + pre_train=False, ): # number of training steps to do before validating a model. This also defines an epoch @@ -300,6 +332,7 @@ def train_and_evaluate( bottleneck=bottleneck, dropout_rate=dropout_rate, input_shape=OUTPUT_SHAPE + (3,), + pre_train=pre_train, ) def scheduler(epoch, lr): @@ -312,8 +345,10 @@ def train_and_evaluate( if epoch in range(200): return 1 * lr + elif epoch < 1000: + return lr * np.exp(-0.005) else: - return lr * tf.math.exp(-0.01) + return 0.0001 if lerning_rate_schedule == "cosine-decay-restarts": decay_steps = 50 @@ -339,16 +374,22 @@ def train_and_evaluate( } callbacks = add_backup_callback(callbacks, backup_dir=f"{checkpoint_path}/backup") - # STEPS_PER_EPOCH - pre_model.fit( - train_ds, - epochs=2, - validation_data=val_ds, - steps_per_epoch=STEPS_PER_EPOCH, - validation_steps=VALIDATION_SAMPLES // VALIDATION_BATCH_SIZE, - callbacks=callbacks, - verbose=2, - ) + + # Train the Cross-entropy model if the case + + if pre_train: + # STEPS_PER_EPOCH + pre_model.fit( + train_ds, + epochs=20, + validation_data=val_ds, + steps_per_epoch=STEPS_PER_EPOCH, + validation_steps=VALIDATION_SAMPLES // VALIDATION_BATCH_SIZE, + callbacks=callbacks, + verbose=2, + ) + # Copying the last variable + arc_model.trainable_variables[-1].assign(pre_model.trainable_variables[-1]) # STEPS_PER_EPOCH # epochs=epochs * KERAS_EPOCH_MULTIPLIER, @@ -404,6 +445,9 @@ if __name__ == "__main__": dropout_rate=float(config["dropout-rate"]), face_size=int(config["face-size"]), validation_path=config["validation-tf-record-path"], - lerning_rate_schedule=config["lerning-rate-schedule"], + lerning_rate_schedule=config["lerning-rate-schedule"] + if "lerning-rate-schedule" in config + else None, + pre_train=args["--pre-train"], ) diff --git a/setup.py b/setup.py index daf7cdb02e6053673d10f8343262a3103b686a8e..34bcf0e25d3e179309f3962b9b60602ea80c265a 100644 --- a/setup.py +++ b/setup.py @@ -146,6 +146,8 @@ setup( "lgbphs = bob.bio.face.config.baseline.lgbphs:pipeline", "lda = bob.bio.face.config.baseline.lda:pipeline", "dummy = bob.bio.face.config.baseline.dummy:pipeline", + "resnet50-msceleb-arcface-2021 = bob.bio.face.config.baseline.resnet50_msceleb_arcface_2021:pipeline", + "mobilenetv2-msceleb-arcface-2021 = bob.bio.face.config.baseline.mobilenetv2_msceleb_arcface_2021", ], "bob.bio.config": [ "facenet-sanderberg = bob.bio.face.config.baseline.facenet_sanderberg", @@ -174,6 +176,8 @@ setup( "fargo = bob.bio.face.config.database.fargo", "meds = bob.bio.face.config.database.meds", "morph = bob.bio.face.config.database.morph", + "resnet50-msceleb-arcface-2021 = bob.bio.face.config.baseline.resnet50_msceleb_arcface_2021", + "mobilenetv2-msceleb-arcface-2021 = bob.bio.face.config.baseline.mobilenetv2_msceleb_arcface_2021", ], "bob.bio.cli": [ "display-face-annotations = bob.bio.face.script.display_face_annotations:display_face_annotations",