#!/usr/bin/env python
# coding: utf-8
"""
Trains a face recognition CNN using the strategy from the paper
"A Discriminative Feature Learning Approach for Deep Face Recognition"
https://ydwen.github.io/papers/WenECCV16.pdf

The default backbone is InceptionResNetV2.

Do `./bin/python centerloss.py --help` for more information
"""

import click
import tensorflow as tf

from bob.bio.face.tensorflow.preprocessing import prepare_dataset
from bob.extension import rc
from bob.learn.tensorflow.callbacks import add_backup_callback
from bob.learn.tensorflow.losses import CenterLoss, CenterLossLayer
from bob.learn.tensorflow.metrics.embedding_accuracy import accuracy_from_embeddings
from bob.learn.tensorflow.models.inception_resnet_v2 import InceptionResNetV2

# CNN Backbone
# Change your NN backbone here
BACKBONE = InceptionResNetV2

# SHAPES EXPECTED FROM THE DATASET USING THIS BACKBONE
DATA_SHAPE = (182, 182, 3)  # size of the face crops in the TF records
DATA_TYPE = tf.uint8
OUTPUT_SHAPE = (160, 160)  # size of the faces fed to the network

AUTOTUNE = tf.data.experimental.AUTOTUNE

# INFORMATION ABOUT THE VALIDATION SET (HERE WE VALIDATE WITH LFW)
VALIDATION_TF_RECORD_PATHS = rc["bob.bio.face.cnn.lfw_tfrecord_path"]

# there are 2812 samples in the validation set
VALIDATION_SAMPLES = 2812
VALIDATION_BATCH_SIZE = 38

# WEIGHTS BETWEEN the two losses
LOSS_WEIGHTS = {"cross_entropy": 1.0, "center_loss": 0.01}


class CenterLossModel(tf.keras.Model):
    """Keras model that jointly minimizes a softmax cross-entropy and a center loss."""

    def compile(
        self,
        cross_entropy,
        center_loss,
        loss_weights,
        train_loss,
        train_cross_entropy,
        train_center_loss,
        test_acc,
        **kwargs,
    ):
        super().compile(**kwargs)
        self.cross_entropy = cross_entropy
        self.center_loss = center_loss
        self.loss_weights = loss_weights
        self.train_loss = train_loss
        self.train_cross_entropy = train_cross_entropy
        self.train_center_loss = train_center_loss
        self.test_acc = test_acc

    def train_step(self, data):
        images, labels = data
        with tf.GradientTape() as tape:
            # The model has two outputs: classification logits and the
            # embeddings (prelogits) on which the center loss operates.
            logits, prelogits = self(images, training=True)
            loss_cross = self.cross_entropy(labels, logits)
            loss_center = self.center_loss(labels, prelogits)
            # Weighted combination of the two losses
            loss = (
                loss_cross * self.loss_weights[self.cross_entropy.name]
                + loss_center * self.loss_weights[self.center_loss.name]
            )
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.train_loss(loss)
        self.train_cross_entropy(loss_cross)
        self.train_center_loss(loss_center)
        return {
            m.name: m.result()
            for m in [
                self.train_loss,
                self.train_cross_entropy,
                self.train_center_loss,
            ]
        }

    def test_step(self, data):
        images, labels = data
        logits, prelogits = self(images, training=False)
        # Validation accuracy is computed from the embeddings, not the logits
        self.test_acc(accuracy_from_embeddings(labels, prelogits))
        return {m.name: m.result() for m in [self.test_acc]}


def create_model(n_classes):
    model = BACKBONE(
        include_top=True,
        classes=n_classes,
        bottleneck=True,
        input_shape=OUTPUT_SHAPE + (3,),
    )
    # Tap the embeddings from the bottleneck layer; the CenterLossLayer holds
    # the (n_classes, n_features) matrix of class centers.
    prelogits = model.get_layer("Bottleneck/BatchNorm").output
    prelogits = CenterLossLayer(
        n_classes=n_classes, n_features=prelogits.shape[-1], name="centers"
    )(prelogits)
    logits = model.get_layer("logits").output
    model = CenterLossModel(
        inputs=model.input, outputs=[logits, prelogits], name=model.name
    )
    return model
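

# For reference, the center loss of Wen et al. penalizes the squared distance
# between each embedding x_i and its class center c_{y_i}:
#     L_c = 0.5 * sum_i ||x_i - c_{y_i}||^2
# while the centers themselves are pulled toward the embeddings at a rate
# controlled by `alpha`. The sketch below is for documentation only: it is NOT
# the `CenterLoss` implementation imported above, and it simplifies the paper's
# update rule by dropping the per-class sample-count normalization.
def _center_loss_sketch(prelogits, labels, centers, alpha=0.9):
    """Illustrative center loss; `centers` is assumed to be a tf.Variable."""
    centers_batch = tf.gather(centers, labels)  # c_{y_i} for each sample
    # Move the selected centers toward their samples (simplified update)
    centers.scatter_sub(tf.IndexedSlices(alpha * (centers_batch - prelogits), labels))
    # 0.5 * ||x_i - c_{y_i}||^2, averaged over the batch
    return tf.reduce_mean(tf.reduce_sum(tf.square(prelogits - centers_batch), axis=1)) / 2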


def build_and_compile_model(n_classes, learning_rate):
    model = create_model(n_classes)

    cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, name="cross_entropy"
    )
    # The CenterLoss shares the `centers` variable created inside the model
    center_loss = CenterLoss(
        centers_layer=model.get_layer("centers"),
        alpha=0.9,
        name="center_loss",
    )

    optimizer = tf.keras.optimizers.RMSprop(
        learning_rate=learning_rate, rho=0.9, momentum=0.9, epsilon=1.0
    )

    train_loss = tf.keras.metrics.Mean(name="loss")
    train_cross_entropy = tf.keras.metrics.Mean(name="cross_entropy")
    train_center_loss = tf.keras.metrics.Mean(name="center_loss")
    test_acc = tf.keras.metrics.Mean(name="accuracy")

    model.compile(
        optimizer=optimizer,
        cross_entropy=cross_entropy,
        center_loss=center_loss,
        loss_weights=LOSS_WEIGHTS,
        train_loss=train_loss,
        train_cross_entropy=train_cross_entropy,
        train_center_loss=train_center_loss,
        test_acc=test_acc,
    )
    return model


@click.command()
@click.argument("tf-record-paths")
@click.argument("checkpoint-path")
@click.option(
    "-n",
    "--n-classes",
    default=87662,
    help="Number of classes in the classification problem. Defaults to `87662`, "
    "which is the number of identities in our pruned MSCeleb.",
)
@click.option(
    "-b",
    "--batch-size",
    default=90,
    help="Batch size. Since training runs in single precision, the batch size "
    "should be as large as memory allows.",
)
@click.option(
    "-e",
    "--epochs",
    default=35,
    help="Number of epochs.",
)
def train_and_evaluate(tf_record_paths, checkpoint_path, n_classes, batch_size, epochs):
    # Number of training steps between validations. This also defines what Keras
    # calls an epoch, even though it is not a real pass over the data. We want to
    # evaluate every 180000 (90 * 2000) samples.
    STEPS_PER_EPOCH = 180000 // batch_size
    learning_rate = 0.1
    # Each real epoch is split into this many Keras epochs
    KERAS_EPOCH_MULTIPLIER = 6

    train_ds = prepare_dataset(
        tf_record_paths,
        batch_size,
        epochs,
        data_shape=DATA_SHAPE,
        output_shape=OUTPUT_SHAPE,
        shuffle=True,
        augment=True,
    )

    if VALIDATION_TF_RECORD_PATHS is None:
        raise ValueError(
            "No validation set was configured. Please run "
            "`bob config set bob.bio.face.cnn.lfw_tfrecord_path [PATH]`."
        )

    val_ds = prepare_dataset(
        VALIDATION_TF_RECORD_PATHS,
        data_shape=DATA_SHAPE,
        output_shape=OUTPUT_SHAPE,
        epochs=epochs,
        batch_size=VALIDATION_BATCH_SIZE,
        shuffle=False,
        augment=False,
    )
    val_metric_name = "val_accuracy"

    model = build_and_compile_model(n_classes, learning_rate)

    def scheduler(epoch, lr):
        # 20 epochs at 0.1, 10 at 0.01, and 5 at 0.001.
        # The epoch number here is Keras's, which differs from the actual epoch
        # number (see KERAS_EPOCH_MULTIPLIER above).
        epoch = epoch // KERAS_EPOCH_MULTIPLIER
        if epoch in range(20):
            return 0.1
        elif epoch in range(20, 30):
            return 0.01
        else:
            return 0.001

    callbacks = {
        "latest": tf.keras.callbacks.ModelCheckpoint(
            f"{checkpoint_path}/latest", verbose=1
        ),
        "best": tf.keras.callbacks.ModelCheckpoint(
            f"{checkpoint_path}/best",
            monitor=val_metric_name,
            save_best_only=True,
            mode="max",
            verbose=1,
        ),
        "tensorboard": tf.keras.callbacks.TensorBoard(
            log_dir=f"{checkpoint_path}/logs", update_freq=15, profile_batch=0
        ),
        "lr": tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1),
        "nan": tf.keras.callbacks.TerminateOnNaN(),
    }
    callbacks = add_backup_callback(callbacks, backup_dir=f"{checkpoint_path}/backup")

    model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs * KERAS_EPOCH_MULTIPLIER,
        steps_per_epoch=STEPS_PER_EPOCH,
        validation_steps=VALIDATION_SAMPLES // VALIDATION_BATCH_SIZE,
        callbacks=callbacks,
        verbose=2,
    )


if __name__ == "__main__":
    train_and_evaluate()
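
# A hypothetical invocation, with placeholder paths (the `bob config set` step
# mirrors the error message raised above when no validation set is configured):
#
#   bob config set bob.bio.face.cnn.lfw_tfrecord_path /path/to/lfw-tfrecords
#   ./bin/python centerloss.py /path/to/train-tfrecords /path/to/checkpoint-dir \
#       --n-classes 87662 --batch-size 90 --epochs 35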