Commit d4021c32 authored by Tiago de Freitas Pereira

Merge branch 'tf-training' into 'master'

TF2 training

See merge request !96
parents 1a99ee82 163c02d1
#!/usr/bin/env python
# coding: utf-8
"""
Tensor pre-processing for some face recognition CNNs
"""
import os
from functools import partial
from multiprocessing import cpu_count
import tensorflow as tf
from tensorflow.keras import layers
# STANDARD FEATURES FROM OUR TF-RECORDS
FEATURES = {
"data": tf.io.FixedLenFeature([], tf.string),
"label": tf.io.FixedLenFeature([], tf.int64),
"key": tf.io.FixedLenFeature([], tf.string),
}
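# Each serialized example stores the raw image bytes ("data"), an integer
# class label ("label") and a string sample key ("key").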
def decode_tfrecords(x, data_shape, data_type=tf.uint8):
features = tf.io.parse_single_example(x, FEATURES)
image = tf.io.decode_raw(features["data"], data_type)
image = tf.reshape(image, data_shape)
features["data"] = image
return features
def get_preprocessor(output_shape):
"""
"""
preprocessor = tf.keras.Sequential(
[
# rotate before cropping
# random rotation of up to +-5 degrees
layers.experimental.preprocessing.RandomRotation(5 / 360),
layers.experimental.preprocessing.RandomCrop(
height=output_shape[0], width=output_shape[1]
),
layers.experimental.preprocessing.RandomFlip("horizontal"),
# FIXED_STANDARDIZATION from https://github.com/davidsandberg/facenet
# [-0.99609375, 0.99609375]
layers.experimental.preprocessing.Rescaling(
scale=1 / 128, offset=-127.5 / 128
),
]
)
return preprocessor
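# A minimal usage sketch (not executed here), assuming uint8 face crops:
# preprocessor = get_preprocessor((160, 160))
# batch = preprocessor(uint8_image_batch, training=True)  # values in ~[-1, 1]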
def preprocess(preprocessor, features, augment=False):
image = features["data"]
label = features["label"]
image = preprocessor(image, training=augment)
return image, label
def prepare_dataset(
tf_record_paths,
batch_size,
epochs,
data_shape,
output_shape,
shuffle=False,
augment=False,
autotune=tf.data.experimental.AUTOTUNE,
n_cpus=cpu_count(),
shuffle_buffer=int(2e4),
):
"""
Create batches from a list of TF-Records
Parameters
----------
tf_record_paths: list
List of paths of the TF-Records
batch_size: int
epochs: int
shuffle: bool
augment: bool
autotune: int
n_cpus: int
shuffle_buffer: int
"""
ds = tf.data.Dataset.list_files(tf_record_paths, shuffle=shuffle)
ds = tf.data.TFRecordDataset(ds, num_parallel_reads=n_cpus)
if shuffle:
# ignore order and read files as soon as they come in
ignore_order = tf.data.Options()
ignore_order.experimental_deterministic = False
ds = ds.with_options(ignore_order)
ds = ds.map(partial(decode_tfrecords, data_shape=data_shape)).prefetch(buffer_size=autotune)
if shuffle:
ds = ds.shuffle(shuffle_buffer).repeat(epochs)
preprocessor = get_preprocessor(output_shape)
ds = ds.batch(batch_size).map(
partial(preprocess, preprocessor, augment=augment), num_parallel_calls=autotune,
)
# Use buffered prefetching on all datasets
return ds.prefetch(buffer_size=autotune)
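# A minimal usage sketch (not executed here); the glob below is hypothetical
# and the shapes match the 182x182 -> 160x160 crop used by the training scripts:
# train_ds = prepare_dataset(
#     "/path/to/shards-*.tfrecord",
#     batch_size=90,
#     epochs=35,
#     data_shape=(182, 182, 3),
#     output_shape=(160, 160),
#     shuffle=True,
#     augment=True,
# )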
#!/usr/bin/env python
# coding: utf-8
"""
Trains some face recognition baselines using ArcFace-based models
# ARCFACE PARAMETERS from eq.4
# FROM https://github.com/deepinsight/insightface/blob/master/recognition/ArcFace/sample_config.py#L153
M1 = 1.0
M2 = 0.3
M3 = 0.2
# ARCFACE PARAMETERS from eq.3
M = 0.5 # ArcFace Margin #CHECK SECTION 3.1
SCALE = 64.0 # Scale
# ORIGINAL = False # Original implementation
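For reference (assuming the InsightFace notation above): the combined-margin
logit of the true class in eq. 4 is s * (cos(m1 * theta + m2) - m3), and the
single-margin ArcFace logit in eq. 3 is s * cos(theta + m), where theta is the
angle between the embedding and the class weight vector and s is the scale.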
The config file has the following format to train an ArcFace model:
```yml
# VGG2 params
batch-size: 90
face-size: 182
face-output-size: 160
n-classes: 87662
## Backbone
backbone: 'mobilenet-v2'
head: 'arcface'
s: 10
bottleneck: 512
m: 0.5
# Training parameters
#solver: "rmsprop"
solver: "sgd"
lr: 0.1
dropout-rate: 0.5
epochs: 310
learning-rate-schedule: 'cosine-decay-restarts'
train-tf-record-path: "/path/*.tfrecord"
validation-tf-record-path: "/path/lfw_pairs.tfrecord"
```
Usage:
arcface.py <config-yaml> <checkpoint_path>
arcface.py -h | --help
Options:
-h --help     Show this screen.
"""
import os
from functools import partial
import pkg_resources
import tensorflow as tf
from bob.learn.tensorflow.models.inception_resnet_v2 import InceptionResNetV2
from bob.learn.tensorflow.metrics import predict_using_tensors
from tensorflow.keras import layers
from bob.learn.tensorflow.callbacks import add_backup_callback
from bob.learn.tensorflow.metrics.embedding_accuracy import accuracy_from_embeddings
from bob.extension import rc
from bob.bio.face.tensorflow.preprocessing import prepare_dataset
import yaml
from bob.learn.tensorflow.layers import (
add_bottleneck,
add_top,
SphereFaceLayer,
ModifiedSoftMaxLayer,
)
from bob.learn.tensorflow.models import (
EmbeddingValidation,
ArcFaceLayer,
ArcFaceModel,
ArcFaceLayer3Penalties,
)
##############################
# CNN Backbones
# Add your NN backbone here
##############################
BACKBONES = dict()
BACKBONES["inception-resnet-v2"] = InceptionResNetV2
BACKBONES["efficientnet-B0"] = tf.keras.applications.EfficientNetB0
BACKBONES["resnet50"] = tf.keras.applications.ResNet50
BACKBONES["mobilenet-v2"] = tf.keras.applications.MobileNetV2
##############################
# SOLVER SPECIFICATIONS
##############################
SOLVERS = dict()
# Parameters taken from https://github.com/davidsandberg/facenet/blob/master/src/facenet.py#L181
# Fixing the start learning rate
learning_rate = 0.1
SOLVERS["rmsprop"] = partial(
tf.keras.optimizers.RMSprop,
learning_rate=learning_rate,
rho=0.9,
momentum=0.9,
epsilon=1.0,
)
SOLVERS["adam"] = partial(tf.keras.optimizers.Adam, learning_rate=learning_rate)
SOLVERS["adagrad"] = partial(tf.keras.optimizers.Adagrad, learning_rate=learning_rate)
SOLVERS["sgd"] = partial(
tf.keras.optimizers.SGD, learning_rate=learning_rate, momentum=0.9, nesterov=True
)
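# Each entry is a functools.partial whose defaults can be overridden at call
# time, e.g. SOLVERS["sgd"](learning_rate=0.01) builds Nesterov SGD with a
# different learning rate.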
################################
# DATA SPECIFICATION
###############################
DATA_SHAPES = dict()
# Maps the stored face size to the random-crop output size
# (e.g. inputs with 182x182 are cropped to 160x160)
DATA_SHAPES[182] = 160
DATA_SHAPES[112] = 98
DATA_SHAPES[126] = 112
# SHAPES EXPECTED FROM THE DATASET USING THIS BACKBONE
# DATA_SHAPE = (182, 182, 3) # size of faces
DATA_TYPE = tf.uint8
# OUTPUT_SHAPE = (160, 160)
AUTOTUNE = tf.data.experimental.AUTOTUNE
# HERE WE VALIDATE WITH LFW
# INFORMATION ABOUT THE VALIDATION SET
# VALIDATION_TF_RECORD_PATHS = rc["bob.bio.face.cnn.lfw_tfrecord_path"]
# there are 2812 samples in the validation set
VALIDATION_SAMPLES = 2812
VALIDATION_BATCH_SIZE = 38
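# 2812 / 38 = 74, so validation runs over exactly 74 full batches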
def create_model(
n_classes, model_spec, backbone, bottleneck, dropout_rate, input_shape
):
if backbone == "inception-resnet-v2":
pre_model = BACKBONES[backbone](
include_top=False, bottleneck=False, input_shape=input_shape,
)
else:
pre_model = BACKBONES[backbone](
include_top=False, input_shape=input_shape, weights=None,
)
# Adding the bottleneck
pre_model = add_bottleneck(
pre_model, bottleneck_size=bottleneck, dropout_rate=dropout_rate
)
pre_model = add_top(pre_model, n_classes=n_classes)
float32_layer = layers.Activation("linear", dtype="float32")
embeddings = tf.nn.l2_normalize(
pre_model.get_layer("embeddings/BatchNorm").output, axis=1
)
logits_premodel = float32_layer(pre_model.get_layer("logits").output)
# Wrapping the embedding validation
pre_model = EmbeddingValidation(
pre_model.input, outputs=[logits_premodel, embeddings], name=pre_model.name
)
################################
## Creating the specific models
if "arcface" in model_spec:
labels = tf.keras.layers.Input([], name="label")
logits_arcface = ArcFaceLayer(
n_classes, s=model_spec["arcface"]["s"], m=model_spec["arcface"]["m"]
)(embeddings, labels)
arc_model = ArcFaceModel(
inputs=(pre_model.input, labels), outputs=[logits_arcface, embeddings]
)
elif "arcface-3p" in model_spec:
labels = tf.keras.layers.Input([], name="label")
logits_arcface = ArcFaceLayer3Penalties(
n_classes,
s=model_spec["arcface-3p"]["s"],
m1=model_spec["arcface-3p"]["m1"],
m2=model_spec["arcface-3p"]["m2"],
m3=model_spec["arcface-3p"]["m3"],
)(embeddings, labels)
arc_model = ArcFaceModel(
inputs=(pre_model.input, labels), outputs=[logits_arcface, embeddings]
)
elif "sphereface" in model_spec:
logits_arcface = SphereFaceLayer(n_classes, m=model_spec["sphereface"]["m"],)(
embeddings
)
arc_model = EmbeddingValidation(
pre_model.input, outputs=[logits_arcface, embeddings]
)
elif "modified-softmax" in model_spec:
logits_modified_softmax = ModifiedSoftMaxLayer(n_classes)(embeddings)
arc_model = EmbeddingValidation(
pre_model.input, outputs=[logits_modified_softmax, embeddings]
)
return pre_model, arc_model
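# A usage sketch with hypothetical values taken from the config example in the
# module docstring:
# pre_model, arc_model = create_model(
#     n_classes=87662,
#     model_spec={"arcface": {"s": 10, "m": 0.5}},
#     backbone="mobilenet-v2",
#     bottleneck=512,
#     dropout_rate=0.5,
#     input_shape=(160, 160, 3),
# )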
def build_and_compile_models(
n_classes, optimizer, model_spec, backbone, bottleneck, dropout_rate, input_shape
):
pre_model, arc_model = create_model(
n_classes, model_spec, backbone, bottleneck, dropout_rate, input_shape
)
cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, name="cross_entropy"
)
pre_model.compile(optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"])
arc_model.compile(optimizer=optimizer, loss=cross_entropy, metrics=["accuracy"])
return pre_model, arc_model
def train_and_evaluate(
tf_record_paths,
checkpoint_path,
n_classes,
batch_size,
epochs,
model_spec,
backbone,
optimizer,
bottleneck,
dropout_rate,
face_size,
validation_path,
learning_rate_schedule,
):
# Number of training steps before validating the model. This also defines an
# "epoch" for Keras, which is not a true pass over the data: here we validate
# every 2000 batches.
# STEPS_PER_EPOCH = 180000 // batch_size
# KERAS_EPOCH_MULTIPLIER = 6
STEPS_PER_EPOCH = 2000
DATA_SHAPE = (face_size, face_size, 3)
OUTPUT_SHAPE = (DATA_SHAPES[face_size], DATA_SHAPES[face_size])
if validation_path is None:
validation_path = rc["bob.bio.face.cnn.lfw_tfrecord_path"]
if validation_path is None:
raise ValueError(
"No validation set was set. Please, do `bob config set bob.bio.face.cnn.lfw_tfrecord_path [PATH]`"
)
train_ds = prepare_dataset(
tf_record_paths,
batch_size,
epochs,
data_shape=DATA_SHAPE,
output_shape=OUTPUT_SHAPE,
shuffle=True,
augment=True,
)
val_ds = prepare_dataset(
validation_path,
data_shape=DATA_SHAPE,
output_shape=OUTPUT_SHAPE,
epochs=epochs,
batch_size=VALIDATION_BATCH_SIZE,
shuffle=False,
augment=False,
)
val_metric_name = "val_accuracy"
pre_model, arc_model = build_and_compile_models(
n_classes,
optimizer,
model_spec,
backbone,
bottleneck=bottleneck,
dropout_rate=dropout_rate,
input_shape=OUTPUT_SHAPE + (3,),
)
def scheduler(epoch, lr):
# Keep the learning rate for the first 200 Keras epochs, then decay it
# exponentially. The epoch number here is Keras's, which differs from the
# actual number of passes over the data.
# epoch = epoch // KERAS_EPOCH_MULTIPLIER
# Track the learning rate in tensorboard
tf.summary.scalar("learning rate", data=lr, step=epoch)
if epoch < 200:
return lr
else:
return lr * tf.math.exp(-0.01)
if learning_rate_schedule == "cosine-decay-restarts":
decay_steps = 50
lr_decayed_fn = tf.keras.callbacks.LearningRateScheduler(
tf.keras.experimental.CosineDecayRestarts(
0.1, decay_steps, t_mul=2.0, m_mul=0.8, alpha=0.1
),
verbose=1,
)
else:
lr_decayed_fn = tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1)
callbacks = {
"latest": tf.keras.callbacks.ModelCheckpoint(
f"{checkpoint_path}/latest", verbose=1
),
"tensorboard": tf.keras.callbacks.TensorBoard(
log_dir=f"{checkpoint_path}/logs", update_freq=15, profile_batch=0
),
"lr": lr_decayed_fn,
"nan": tf.keras.callbacks.TerminateOnNaN(),
}
callbacks = add_backup_callback(callbacks, backup_dir=f"{checkpoint_path}/backup")
# First train the plain softmax pre-model for 2 epochs, then the margin-based model below
pre_model.fit(
train_ds,
epochs=2,
validation_data=val_ds,
steps_per_epoch=STEPS_PER_EPOCH,
validation_steps=VALIDATION_SAMPLES // VALIDATION_BATCH_SIZE,
callbacks=callbacks,
verbose=2,
)
# Now train the margin-based (arc) model for the configured number of epochs
arc_model.fit(
train_ds,
validation_data=val_ds,
epochs=epochs,
steps_per_epoch=STEPS_PER_EPOCH,
validation_steps=VALIDATION_SAMPLES // VALIDATION_BATCH_SIZE,
callbacks=callbacks,
verbose=2,
)
from docopt import docopt
if __name__ == "__main__":
args = docopt(__doc__)
with open(args["<config-yaml>"]) as f:
config = yaml.full_load(f)
model_spec = dict()
if config["head"] == "arcface":
model_spec["arcface"] = dict()
model_spec["arcface"]["m"] = float(config["m"])
model_spec["arcface"]["s"] = int(config["s"])
if config["head"] == "arcface-3p":
model_spec["arcface-3p"] = dict()
model_spec["arcface-3p"]["m1"] = float(config["m1"])
model_spec["arcface-3p"]["m2"] = float(config["m2"])
model_spec["arcface-3p"]["m3"] = float(config["m3"])
model_spec["arcface-3p"]["s"] = int(config["s"])
if config["head"] == "sphereface":
model_spec["sphereface"] = dict()
model_spec["sphereface"]["m"] = float(config["m"])
if config["head"] == "modified-softmax":
# There's no hyper parameter here
model_spec["modified-softmax"] = dict()
train_and_evaluate(
config["train-tf-record-path"],
args["<checkpoint_path>"],
int(config["n-classes"]),
int(config["batch-size"]),
int(config["epochs"]),
model_spec,
config["backbone"],
optimizer=SOLVERS[config["solver"]](learning_rate=float(config["lr"])),
bottleneck=int(config["bottleneck"]),
dropout_rate=float(config["dropout-rate"]),
face_size=int(config["face-size"]),
validation_path=config["validation-tf-record-path"],
learning_rate_schedule=config["learning-rate-schedule"],
)
#!/usr/bin/env python
# coding: utf-8
"""
Trains a face recognition CNN using the strategy from the paper
"A Discriminative Feature Learning Approach
for Deep Face Recognition" https://ydwen.github.io/papers/WenECCV16.pdf
The default backbone is the InceptionResnetv2
Do `./bin/python centerloss.py --help` for more information
"""
import os
from functools import partial
import click
import pkg_resources
import tensorflow as tf
from bob.learn.tensorflow.losses import CenterLoss, CenterLossLayer
from bob.learn.tensorflow.models.inception_resnet_v2 import InceptionResNetV2
from bob.learn.tensorflow.metrics import predict_using_tensors
from tensorflow.keras import layers
from bob.learn.tensorflow.callbacks import add_backup_callback
from bob.learn.tensorflow.metrics.embedding_accuracy import accuracy_from_embeddings
from bob.extension import rc
from bob.bio.face.tensorflow.preprocessing import prepare_dataset
# CNN Backbone
# Change your NN backbone here
BACKBONE = InceptionResNetV2
# SHAPES EXPECTED FROM THE DATASET USING THIS BACKBONE
DATA_SHAPE = (182, 182, 3) # size of faces
DATA_TYPE = tf.uint8
OUTPUT_SHAPE = (160, 160)
AUTOTUNE = tf.data.experimental.AUTOTUNE
# HERE WE VALIDATE WITH LFW
# INFORMATION ABOUT THE VALIDATION SET
VALIDATION_TF_RECORD_PATHS = rc["bob.bio.face.cnn.lfw_tfrecord_path"]
# there are 2812 samples in the validation set
VALIDATION_SAMPLES = 2812
VALIDATION_BATCH_SIZE = 38
# WEIGHTS BETWEEN the two losses
LOSS_WEIGHTS = {"cross_entropy": 1.0, "center_loss": 0.01}
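# i.e. total_loss = 1.0 * cross_entropy + 0.01 * center_loss (see train_step below)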
class CenterLossModel(tf.keras.Model):
def compile(
self,
cross_entropy,
center_loss,
loss_weights,
train_loss,
train_cross_entropy,
train_center_loss,
test_acc,
**kwargs,
):
super().compile(**kwargs)
self.cross_entropy = cross_entropy
self.center_loss = center_loss
self.loss_weights = loss_weights
self.train_loss = train_loss
self.train_cross_entropy = train_cross_entropy
self.train_center_loss = train_center_loss
self.test_acc = test_acc
def train_step(self, data):
images, labels = data
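# Forward pass under the tape, combine the two losses with their configured
# weights, then apply the gradients manually.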
with tf.GradientTape() as tape:
logits, prelogits = self(images, training=True)
loss_cross = self.cross_entropy(labels, logits)
loss_center = self.center_loss(labels, prelogits)
loss = (
loss_cross * self.loss_weights[self.cross_entropy.name]
+ loss_center * self.loss_weights[self.center_loss.name]
)
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
self.train_loss(loss)
self.train_cross_entropy(loss_cross)
self.train_center_loss(loss_center)
return {
m.name: m.result()
for m in [self.train_loss, self.train_cross_entropy, self.train_center_loss]
}
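# Validation accuracy is computed from the embeddings (prelogits) on the LFW
# pairs, not from the classification logits.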
def test_step(self, data):
images, labels = data
logits, prelogits = self(images, training=False)
self.test_acc(accuracy_from_embeddings(labels, prelogits))
return {m.name: m.result() for m in [self.test_acc]}
def create_model(n_classes):
model = BACKBONE(
include_top=True,
classes=n_classes,
bottleneck=True,
input_shape=OUTPUT_SHAPE + (3,),
)
prelogits = model.get_layer("Bottleneck/BatchNorm").output
prelogits = CenterLossLayer(
n_classes=n_classes, n_features=prelogits.shape[-1], name="centers"
)(prelogits)
logits = model.get_layer("logits").output
model = CenterLossModel(
inputs=model.input, outputs=[logits, prelogits], name=model.name
)
return model
def build_and_compile_model(n_classes, learning_rate):
model = create_model(n_classes)
cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, name="cross_entropy"
)
center_loss = CenterLoss(
centers_layer=model.get_layer("centers"), alpha=0.9, name="center_loss",
)
optimizer = tf.keras.optimizers.RMSprop(
learning_rate=learning_rate, rho=0.9, momentum=0.9, epsilon=1.0
)
train_loss = tf.keras.metrics.Mean(name="loss")
train_cross_entropy = tf.keras.metrics.Mean(name="cross_entropy")
train_center_loss = tf.keras.metrics.Mean(name="center_loss")
test_acc = tf.keras.metrics.Mean(name="accuracy")
model.compile(
optimizer=optimizer,
cross_entropy=cross_entropy,
center_loss=center_loss,
loss_weights=LOSS_WEIGHTS,
train_loss=train_loss,
train_cross_entropy=train_cross_entropy,
train_center_loss=train_center_loss,
test_acc=test_acc,
)
return model
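# A usage sketch with the defaults exposed by the CLI below:
# model = build_and_compile_model(n_classes=87662, learning_rate=0.1)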
@click.command()
@click.argument("tf-record-paths")
@click.argument("checkpoint-path")
@click.option(
"-n",
"--n-classes",
default=87662,
help="Number of classes in the classification problem. Default to `87662`, which is the number of identities in our pruned MSCeleb",
)
@click.option(
"-b",
"--batch-size",
default=90,
help="Batch size. Be aware that we are using single precision. Batch size should be high.",
)
@click.option(
"-e", "--epochs", default=35, help="Number of epochs",
)
def train_and_evaluate(tf_record_paths, checkpoint_path, n_classes, batch_size, epochs):
# Number of training steps before validating the model. This also defines an
# "epoch" for Keras, which is not a true pass over the data: we evaluate every
# 180000 (90 * 2000) samples.
STEPS_PER_EPOCH = 180000 // batch_size
learning_rate = 0.1
KERAS_EPOCH_MULTIPLIER = 6
train_ds = prepare_dataset(
tf_record_paths,
batch_size,
epochs,
data_shape=DATA_SHAPE,
output_shape=OUTPUT_SHAPE,
shuffle=True,
augment=True,
)
if VALIDATION_TF_RECORD_PATHS is None:
raise ValueError(
"No validation set was set. Please, do `bob config set bob.bio.face.cnn.lfw_tfrecord_path [PATH]`"
)
val_ds = prepare_dataset(
VALIDATION_TF_RECORD_PATHS,
data_shape=DATA_SHAPE,
output_shape=OUTPUT_SHAPE,
epochs=epochs,
batch_size=VALIDATION_BATCH_SIZE,
shuffle=False,
augment=False,
)
val_metric_name = "val_accuracy"
model = build_and_compile_model(n_classes, learning_rate)
def scheduler(epoch, lr):
# 20 epochs at 0.1, 10 at 0.01 and 5 at 0.001
# The epoch number here is Keras's, which differs from the actual epoch number
epoch = epoch // KERAS_EPOCH_MULTIPLIER
if epoch < 20:
return 0.1
elif epoch < 30:
return 0.01
else:
return 0.001
callbacks = {
"latest": tf.keras.callbacks.ModelCheckpoint(
f"{checkpoint_path}/latest", verbose=1
),
"best": tf.keras.callbacks.ModelCheckpoint(
f"{checkpoint_path}/best",
monitor=val_metric_name,
save_best_only=True,
mode="max",
verbose=1,
),
"tensorboard": tf.keras.callbacks.TensorBoard(
log_dir=f"{checkpoint_path}/logs", update_freq=15, profile_batch=0
),
"lr": tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1),
"nan": tf.keras.callbacks.TerminateOnNaN(),
}
callbacks = add_backup_callback(callbacks, backup_dir=f"{checkpoint_path}/backup")
model.fit(
train_ds,
validation_data=val_ds,
epochs=epochs * KERAS_EPOCH_MULTIPLIER,
steps_per_epoch=STEPS_PER_EPOCH,
validation_steps=VALIDATION_SAMPLES // VALIDATION_BATCH_SIZE,
callbacks=callbacks,
verbose=2,
)
if __name__ == "__main__":
train_and_evaluate()
#!/usr/bin/env python
# coding: utf-8
"""
Trains a face recognition CNN using the strategy from the paper
"A Discriminative Feature Learning Approach
for Deep Face Recognition" https://ydwen.github.io/papers/WenECCV16.pdf
#########
# THIS ONE USES FLOAT16 TO COMPUTE THE GRADIENTS
# CHECK HERE FOR MORE INFO: https://www.tensorflow.org/api_docs/python/tf/keras/mixed_precision/experimental/Policy
########
The default backbone is the InceptionResnetv2
Do `./bin/python centerloss_mixed_precision.py --help` for more information
"""
import os
from functools import partial
import click
import pkg_resources
import tensorflow as tf
from bob.learn.tensorflow.losses import CenterLoss, CenterLossLayer
from bob.learn.tensorflow.models.inception_resnet_v2 import InceptionResNetV2
from bob.learn.tensorflow.metrics import predict_using_tensors
from tensorflow.keras import layers
from tensorflow.keras.mixed_precision import experimental as mixed_precision
from bob.learn.tensorflow.callbacks import add_backup_callback
from bob.learn.tensorflow.metrics.embedding_accuracy import accuracy_from_embeddings
from bob.extension import rc
from bob.bio.face.tensorflow.preprocessing import prepare_dataset
# Setting mixed precision policy
# https://www.tensorflow.org/api_docs/python/tf/keras/mixed_precision/experimental/Policy
policy = mixed_precision.Policy("mixed_float16")
mixed_precision.set_policy(policy)
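# Under this policy layers compute in float16 while keeping float32 variables;
# the outputs feeding the losses are cast back to float32 in create_model below.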
# CNN Backbone
# Change your NN backbone here
BACKBONE = InceptionResNetV2
# SHAPES EXPECTED FROM THE DATASET USING THIS BACKBONE
DATA_SHAPE = (182, 182, 3) # size of faces
DATA_TYPE = tf.uint8
OUTPUT_SHAPE = (160, 160)
AUTOTUNE = tf.data.experimental.AUTOTUNE
# HERE WE VALIDATE WITH LFW
# INFORMATION ABOUT THE VALIDATION SET
VALIDATION_TF_RECORD_PATHS = rc["bob.bio.face.cnn.lfw_tfrecord_path"]
# there are 2812 samples in the validation set
VALIDATION_SAMPLES = 2812
VALIDATION_BATCH_SIZE = 38
# WEIGHTS BETWEEN the two losses
LOSS_WEIGHTS = {"cross_entropy": 1.0, "center_loss": 0.01}
class CenterLossModel(tf.keras.Model):
def compile(
self,
cross_entropy,
center_loss,
loss_weights,
train_loss,
train_cross_entropy,
train_center_loss,
test_acc,
global_batch_size,
**kwargs,
):
super().compile(**kwargs)
self.cross_entropy = cross_entropy
self.center_loss = center_loss
self.loss_weights = loss_weights
self.train_loss = train_loss
self.train_cross_entropy = train_cross_entropy
self.train_center_loss = train_center_loss
self.test_acc = test_acc
self.global_batch_size = global_batch_size
def train_step(self, data):
images, labels = data
with tf.GradientTape() as tape:
logits, prelogits = self(images, training=True)
loss_cross = self.cross_entropy(labels, logits)
loss_center = self.center_loss(labels, prelogits)
loss = (
loss_cross * self.loss_weights[self.cross_entropy.name]
+ loss_center * self.loss_weights[self.center_loss.name]
)
unscaled_loss = tf.nn.compute_average_loss(
loss, global_batch_size=self.global_batch_size
)
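# Scale the loss so that float16 gradients do not underflow; the gradients are
# unscaled again below before being applied.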
loss = self.optimizer.get_scaled_loss(unscaled_loss)
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
gradients = self.optimizer.get_unscaled_gradients(gradients)
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
self.train_loss(unscaled_loss)
self.train_cross_entropy(loss_cross)
self.train_center_loss(loss_center)
return {
m.name: m.result()
for m in [self.train_loss, self.train_cross_entropy, self.train_center_loss]
}
def test_step(self, data):
images, labels = data
logits, prelogits = self(images, training=False)
self.test_acc(accuracy_from_embeddings(labels, prelogits))
return {m.name: m.result() for m in [self.test_acc]}
def create_model(n_classes):
model = BACKBONE(
include_top=True,
classes=n_classes,
bottleneck=True,
input_shape=OUTPUT_SHAPE + (3,),
kernel_regularizer=tf.keras.regularizers.L2(5e-5),
)
float32_layer = layers.Activation("linear", dtype="float32")
prelogits = model.get_layer("Bottleneck/BatchNorm").output
prelogits = CenterLossLayer(
n_classes=n_classes, n_features=prelogits.shape[-1], name="centers"
)(prelogits)
prelogits = float32_layer(prelogits)
logits = float32_layer(model.get_layer("logits").output)
model = CenterLossModel(
inputs=model.input, outputs=[logits, prelogits], name=model.name
)
return model
def build_and_compile_model(n_classes, learning_rate, global_batch_size):
model = create_model(n_classes)
cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, name="cross_entropy", reduction=tf.keras.losses.Reduction.NONE
)
center_loss = CenterLoss(
centers_layer=model.get_layer("centers"),
alpha=0.9,
name="center_loss",
reduction=tf.keras.losses.Reduction.NONE,
)
optimizer = tf.keras.optimizers.RMSprop(
learning_rate=learning_rate, rho=0.9, momentum=0.9, epsilon=1.0
)
optimizer = mixed_precision.LossScaleOptimizer(optimizer, loss_scale="dynamic")
train_loss = tf.keras.metrics.Mean(name="loss")
train_cross_entropy = tf.keras.metrics.Mean(name="cross_entropy")
train_center_loss = tf.keras.metrics.Mean(name="center_loss")
test_acc = tf.keras.metrics.Mean(name="accuracy")
model.compile(
optimizer=optimizer,
cross_entropy=cross_entropy,
center_loss=center_loss,
loss_weights=LOSS_WEIGHTS,
train_loss=train_loss,
train_cross_entropy=train_cross_entropy,
train_center_loss=train_center_loss,
test_acc=test_acc,
global_batch_size=global_batch_size,
)
return model
@click.command()
@click.argument("tf-record-paths")
@click.argument("checkpoint-path")
@click.option(
"-n",
"--n-classes",
default=87662,
help="Number of classes in the classification problem. Default to `87662`, which is the number of identities in our pruned MSCeleb",
)
@click.option(
"-b",
"--batch-size",
default=90 * 2,
help="Batch size. Be aware that we are using single precision. Batch size should be high.",
)
@click.option(
"-e", "--epochs", default=35, help="Number of epochs",
)
def train_and_evaluate(tf_record_paths, checkpoint_path, n_classes, batch_size, epochs):
# Number of training steps before validating the model. This also defines an
# "epoch" for Keras, which is not a true pass over the data: we evaluate every
# 180000 (90 * 2000) samples.
STEPS_PER_EPOCH = 180000 // batch_size
learning_rate = 0.1
KERAS_EPOCH_MULTIPLIER = 6
train_ds = prepare_dataset(
tf_record_paths,
batch_size,
epochs,
data_shape=DATA_SHAPE,
output_shape=OUTPUT_SHAPE,
shuffle=True,
augment=True,
)
if VALIDATION_TF_RECORD_PATHS is None:
raise ValueError(
"No validation set was set. Please, do `bob config set bob.bio.face.cnn.lfw_tfrecord_path [PATH]`"
)
val_ds = prepare_dataset(
VALIDATION_TF_RECORD_PATHS,
data_shape=DATA_SHAPE,
output_shape=OUTPUT_SHAPE,
epochs=epochs,
batch_size=VALIDATION_BATCH_SIZE,
shuffle=False,
augment=False,
)
val_metric_name = "val_accuracy"
model = build_and_compile_model(
n_classes, learning_rate, global_batch_size=batch_size
)
def scheduler(epoch, lr):
# 20 epochs at 0.1, 10 at 0.01 and 5 at 0.001
# The epoch number here is Keras's, which differs from the actual epoch number
epoch = epoch // KERAS_EPOCH_MULTIPLIER
if epoch < 20:
return 0.1
elif epoch < 30:
return 0.01
else:
return 0.001
callbacks = {
"latest": tf.keras.callbacks.ModelCheckpoint(
f"{checkpoint_path}/latest", verbose=1
),
"best": tf.keras.callbacks.ModelCheckpoint(
f"{checkpoint_path}/best",
monitor=val_metric_name,
save_best_only=True,
mode="max",
verbose=1,
),
"tensorboard": tf.keras.callbacks.TensorBoard(
log_dir=f"{checkpoint_path}/logs", update_freq=15, profile_batch="10,50"
),
"lr": tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1),
"nan": tf.keras.callbacks.TerminateOnNaN(),
}
callbacks = add_backup_callback(callbacks, backup_dir=f"{checkpoint_path}/backup")
model.fit(
train_ds,
validation_data=val_ds,
epochs=epochs * KERAS_EPOCH_MULTIPLIER,
steps_per_epoch=STEPS_PER_EPOCH,
validation_steps=VALIDATION_SAMPLES // VALIDATION_BATCH_SIZE,
callbacks=callbacks,
verbose=2,
)
if __name__ == "__main__":
train_and_evaluate()