Commit babec5d4 authored by Tiago de Freitas Pereira

ArcFace/SphereFace Loss

Apply suggestion to bob/learn/tensorflow/layers.py

Apply suggestion to bob/learn/tensorflow/layers.py

Rewiring Arcface to make it simpler

Rewiring Arcface to make it simpler

Few changes

Removing test

Removing test

Added test case
parent aaf932d2
1 merge request: !93 ArcFace/SphereFace Loss
Pipeline #47379 passed
import numbers
import tensorflow as tf
import math


def _check_input(
@@ -160,3 +161,143 @@ def Normalize(mean, std=1.0, **kwargs):
    return tf.keras.layers.experimental.preprocessing.Rescaling(
        scale=scale, offset=offset, **kwargs
    )


class SphereFaceLayer(tf.keras.layers.Layer):
    """
    Implements the SphereFace logits from equation (7) of `SphereFace: Deep Hypersphere Embedding for Face Recognition <https://arxiv.org/abs/1704.08063>`_

    It computes the per-class posterior of eq. (7): :math:`\\text{soft}(x_i) = \\frac{exp(||x_i||\\psi(\\theta_{y_i}))}{exp(||x_i||\\psi(\\theta_{y_i})) + \\sum_{j \\neq y_i} exp(||x_i||\\text{cos}(\\theta_{j}))}`,
    where :math:`\\psi(\\theta) = (-1)^k \\text{cos}(m\\theta) - 2k` and :math:`k = \\lfloor m\\theta / \\pi \\rfloor`.

    Parameters
    ----------
    n_classes: int
        Number of classes

    m: float
        Margin
    """

    def __init__(self, n_classes=10, m=0.5):
        super(SphereFaceLayer, self).__init__(name="sphere_face_logits")
        self.n_classes = n_classes
        self.m = m

    def build(self, input_shape):
        super(SphereFaceLayer, self).build(input_shape[0])
        shape = [input_shape[-1], self.n_classes]
        self.W = self.add_variable("W", shape=shape)
        self.pi = tf.constant(math.pi)

    def call(self, X, training=None):
        # ||x|| of the input features (taken before normalization)
        x_norm = tf.norm(X, axis=-1, keepdims=True)

        # normalize feature
        X = tf.nn.l2_normalize(X, axis=1)
        W = tf.nn.l2_normalize(self.W, axis=0)

        # cos between X and W
        cos_yi = tf.matmul(X, W)

        # cos(m \theta)
        theta = tf.math.acos(cos_yi)
        cos_theta_m = tf.math.cos(self.m * theta)

        # phi = (-1)**k * cos(m \theta) - 2k, with k = floor(m * theta / pi)
        k = tf.math.floor(self.m * theta / self.pi)
        sign = 1.0 - 2.0 * tf.math.mod(k, 2.0)  # (-1)**k without a negative base
        phi = sign * cos_theta_m - 2.0 * k

        logits = x_norm * phi

        return logits
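

# A minimal usage sketch: the Dense "backbone" and all sizes below are arbitrary
# stand-ins, not part of this module. The SphereFace logits are trained with a
# softmax cross-entropy over integer labels.
import numpy as np
import tensorflow as tf
from bob.learn.tensorflow.layers import SphereFaceLayer

inputs = tf.keras.Input(shape=(64,))
embeddings = tf.keras.layers.Dense(32)(inputs)  # stand-in embedding extractor
logits = SphereFaceLayer(n_classes=10)(embeddings)
model = tf.keras.Model(inputs, logits)

model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
model.fit(
    np.random.rand(16, 64).astype("float32"),
    np.random.randint(10, size=16),
    verbose=0,
)

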
class ModifiedSoftMaxLayer(tf.keras.layers.Layer):
    """
    Implements the modified softmax logits from equation (5) of `SphereFace: Deep Hypersphere Embedding for Face Recognition <https://arxiv.org/pdf/1704.08063.pdf>`_

    It replaces the regular logit :math:`W_j^Tx_i + b_j` with :math:`||x_i||\\text{cos}(\\theta_{j})`, where :math:`\\theta_{j}` is the angle between the feature :math:`x_i` and the L2-normalized weight column :math:`W_j`.

    Parameters
    ----------
    n_classes: int
        Number of classes for the new logit function
    """

    def __init__(self, n_classes=10):
        super(ModifiedSoftMaxLayer, self).__init__(name="modified_softmax_logits")
        self.n_classes = n_classes

    def build(self, input_shape):
        super(ModifiedSoftMaxLayer, self).build(input_shape[0])
        shape = [input_shape[-1], self.n_classes]
        self.W = self.add_variable("W", shape=shape)

    def call(self, X, training=None):
        # normalize the weights; the features keep their norm
        W = tf.nn.l2_normalize(self.W, axis=0)

        # cos between X and W
        cos_yi = tf.nn.l2_normalize(X, axis=1) @ W

        # ||x_i|| * cos(theta), using the per-sample feature norm
        logits = tf.norm(X, axis=-1, keepdims=True) * cos_yi

        return logits
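

# A quick numeric check (sizes and values below are arbitrary, illustrative only):
# the logits should equal ||x_i|| * cos(theta_j), with theta_j the angle between
# the feature x_i and the L2-normalized weight column W_j.
import numpy as np
from bob.learn.tensorflow.layers import ModifiedSoftMaxLayer

layer = ModifiedSoftMaxLayer(n_classes=5)
X = np.random.rand(4, 8).astype("float32")
logits = layer(X).numpy()  # the first call also builds layer.W

W = layer.W.numpy()
cos_theta = (X / np.linalg.norm(X, axis=1, keepdims=True)) @ (
    W / np.linalg.norm(W, axis=0, keepdims=True)
)
expected = np.linalg.norm(X, axis=1, keepdims=True) * cos_theta
assert np.allclose(logits, expected, atol=1e-5)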
from tensorflow.keras.layers import (
    BatchNormalization,
    Dropout,
    Dense,
    Concatenate,
    GlobalAvgPool2D,
)


def add_bottleneck(model, bottleneck_size=128, dropout_rate=0.2):
    """
    Appends a bottleneck (global average pooling, dropout, dense embedding and
    batch normalization) to a Keras model

    Parameters
    ----------
    model:
        Keras model
    bottleneck_size: int
        Size of the bottleneck
    dropout_rate: float
        Dropout rate
    """
    if not isinstance(model, tf.keras.models.Sequential):
        new_model = tf.keras.models.Sequential(model, name="bottleneck")
    else:
        new_model = model

    new_model.add(GlobalAvgPool2D())
    new_model.add(Dropout(dropout_rate, name="Dropout"))
    # use the requested bottleneck size for the embedding layer
    new_model.add(Dense(bottleneck_size, use_bias=False, name="embeddings"))
    new_model.add(BatchNormalization(axis=-1, scale=False, name="embeddings/BatchNorm"))

    return new_model


def add_top(model, n_classes):
    """
    Appends a logits layer with `n_classes` outputs to a Keras model
    """
    if not isinstance(model, tf.keras.models.Sequential):
        new_model = tf.keras.models.Sequential(model, name="logits")
    else:
        new_model = model

    new_model.add(Dense(n_classes, name="logits"))

    return new_model
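

# A minimal usage sketch, using the two helpers defined above on an arbitrary
# stand-in convolutional backbone: `add_bottleneck` pools the feature maps into
# an embedding and `add_top` appends the class logits.
import tensorflow as tf

backbone = tf.keras.Sequential(
    [
        tf.keras.layers.Conv2D(16, 3, activation="relu", input_shape=(28, 28, 3)),
        tf.keras.layers.Conv2D(32, 3, activation="relu"),
    ]
)
model = add_top(add_bottleneck(backbone, bottleneck_size=128), n_classes=10)
model.summary()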
@@ -3,7 +3,8 @@ from .densenet import DeepPixBiS
from .densenet import DenseNet
from .densenet import densenet161  # noqa: F401
from .mine import MineModel
from .embedding_validation import EmbeddingValidation
from .arcface import ArcFaceLayer, ArcFaceLayer3Penalties, ArcFaceModel
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
@@ -21,5 +22,14 @@ def __appropriate__(*args):
        obj.__module__ = __name__


__appropriate__(
    AlexNet_simplified,
    DenseNet,
    DeepPixBiS,
    MineModel,
    ArcFaceLayer,
    ArcFaceLayer3Penalties,
    ArcFaceModel,
    EmbeddingValidation,
)
__all__ = [_ for _ in dir() if not _.startswith("_")]
import tensorflow as tf
from .embedding_validation import EmbeddingValidation
from bob.learn.tensorflow.metrics.embedding_accuracy import accuracy_from_embeddings
import math
class ArcFaceModel(EmbeddingValidation):
    """
    Keras model for ArcFace training: the labels are fed to the model together
    with the images (the ArcFace logit layers need them to build the margin
    mask), while validation accuracy is computed from the embeddings.
    """

    def train_step(self, data):
        X, y = data

        with tf.GradientTape() as tape:
            logits, _ = self((X, y), training=True)
            loss = self.compiled_loss(
                y, logits, sample_weight=None, regularization_losses=self.losses
            )

        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
        self.compiled_metrics.update_state(y, logits, sample_weight=None)
        self.train_loss(loss)
        return {m.name: m.result() for m in self.metrics + [self.train_loss]}

    def test_step(self, data):
        """
        Test Step
        """

        images, labels = data

        # the labels are only needed to feed the graph inputs;
        # validation accuracy is computed from the embeddings
        _, embeddings = self((images, labels), training=False)
        self.validation_acc(accuracy_from_embeddings(labels, embeddings))
        return {m.name: m.result() for m in [self.validation_acc]}
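

# A minimal wiring sketch (an assumed toy architecture, not something shipped by
# the library): the model takes (image, label) pairs and returns
# (logits, embeddings), which is the contract expected by train_step/test_step
# above. It uses the ArcFaceLayer defined just below.
import tensorflow as tf
from bob.learn.tensorflow.models import ArcFaceLayer, ArcFaceModel

image_in = tf.keras.Input(shape=(112, 112, 3), name="image")
label_in = tf.keras.Input(shape=(), dtype="int32", name="label")

x = tf.keras.layers.Conv2D(16, 3, activation="relu")(image_in)  # stand-in backbone
x = tf.keras.layers.GlobalAvgPool2D()(x)
embeddings = tf.keras.layers.Dense(128, name="embeddings")(x)

logits = ArcFaceLayer(n_classes=10, s=30, m=0.5)(embeddings, label_in)

model = ArcFaceModel(inputs=(image_in, label_in), outputs=(logits, embeddings))
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
# model.fit can then be called with batches of (images, labels).

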
class ArcFaceLayer(tf.keras.layers.Layer):
    """
    Implements the ArcFace logits from equation (3) of `ArcFace: Additive Angular Margin Loss for Deep Face Recognition <https://arxiv.org/abs/1801.07698>`_

    For the target class the logit is defined as:

    :math:`s \\cdot \\text{cos}(\\theta_{y_i} + m)`

    Parameters
    ----------
    n_classes: int
        Number of classes
    m: float
        Margin
    s: int
        Scale
    """

    def __init__(self, n_classes=10, s=30, m=0.5):
        super(ArcFaceLayer, self).__init__(name="arc_face_logits")
        self.n_classes = n_classes
        self.s = s
        self.m = m

    def build(self, input_shape):
        super(ArcFaceLayer, self).build(input_shape[0])
        shape = [input_shape[-1], self.n_classes]
        self.W = self.add_variable("W", shape=shape)

        self.cos_m = tf.identity(math.cos(self.m), name="cos_m")
        self.sin_m = tf.identity(math.sin(self.m), name="sin_m")
        self.th = tf.identity(math.cos(math.pi - self.m), name="th")
        self.mm = tf.identity(math.sin(math.pi - self.m) * self.m)

    def call(self, X, y, training=None):
        # normalize feature
        X = tf.nn.l2_normalize(X, axis=1)
        W = tf.nn.l2_normalize(self.W, axis=0)

        # cos between X and W
        cos_yi = tf.matmul(X, W)

        # sin_yi = sqrt(1 - cos_yi**2); clip before the sqrt so that numerical
        # error in cos_yi cannot produce a NaN
        sin_yi = tf.math.sqrt(tf.clip_by_value(1.0 - cos_yi ** 2, 0.0, 1.0))

        # cos(x+m) = cos(x)*cos(m) - sin(x)*sin(m)
        cos_yi_m = cos_yi * self.cos_m - sin_yi * self.sin_m

        # when theta + m exceeds pi, fall back to the linear penalty cos_yi - mm
        cos_yi_m = tf.where(cos_yi > self.th, cos_yi_m, cos_yi - self.mm)

        # one-hot mask selecting the target-class logits
        one_hot = tf.one_hot(
            tf.cast(y, tf.int32), depth=self.n_classes, name="one_hot_mask"
        )

        logits = (one_hot * cos_yi_m) + ((1.0 - one_hot) * cos_yi)
        logits = self.s * logits

        return logits
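

# A quick sanity check (arbitrary sizes, illustrative only): with margin m=0 and
# scale s=1 the ArcFace logits reduce to the plain cosine similarities between
# the normalized features and the normalized weight columns.
import numpy as np
from bob.learn.tensorflow.models import ArcFaceLayer

layer = ArcFaceLayer(n_classes=5, s=1, m=0.0)
X = np.random.rand(4, 8).astype("float32")
y = np.random.randint(5, size=4)
logits = layer(X, y).numpy()

W = layer.W.numpy()
cos = (X / np.linalg.norm(X, axis=1, keepdims=True)) @ (
    W / np.linalg.norm(W, axis=0, keepdims=True)
)
assert np.allclose(logits, cos, atol=1e-5)

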
class ArcFaceLayer3Penalties(tf.keras.layers.Layer):
    """
    Implements the combined-margin ArcFace logits from equation (4) of `ArcFace: Additive Angular Margin Loss for Deep Face Recognition <https://arxiv.org/abs/1801.07698>`_

    For the target class the logit is defined as:

    :math:`s \\cdot (\\text{cos}(m_1\\theta_{y_i} + m_2) - m_3)`
    """

    def __init__(self, n_classes=10, s=30, m1=0.5, m2=0.5, m3=0.5):
        super(ArcFaceLayer3Penalties, self).__init__(name="arc_face_logits")
        self.n_classes = n_classes
        self.s = s
        self.m1 = m1
        self.m2 = m2
        self.m3 = m3

    def build(self, input_shape):
        super(ArcFaceLayer3Penalties, self).build(input_shape[0])
        shape = [input_shape[-1], self.n_classes]
        self.W = self.add_variable("W", shape=shape)

    def call(self, X, y, training=None):
        # normalize feature
        X = tf.nn.l2_normalize(X, axis=1)
        W = tf.nn.l2_normalize(self.W, axis=0)

        # cos between X and W
        cos_yi = tf.matmul(X, W)

        # getting the angle
        theta = tf.math.acos(cos_yi)

        # combined margin: cos(m1 * theta + m2) - m3
        cos_yi_m = tf.math.cos(self.m1 * theta + self.m2) - self.m3

        # one-hot mask selecting the target-class logits
        one_hot = tf.one_hot(
            tf.cast(y, tf.int32), depth=self.n_classes, name="one_hot_mask"
        )

        logits = (one_hot * cos_yi_m) + ((1.0 - one_hot) * cos_yi)
        logits = self.s * logits

        return logits
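

# A quick numeric check of the combined margin (arbitrary sizes and margin
# values): for the target class the logit is s * (cos(m1*theta + m2) - m3), for
# every other class it is s * cos(theta). With m1=1 and m3=0 this reduces to the
# ArcFaceLayer margin cos(theta + m2).
import numpy as np
from bob.learn.tensorflow.models import ArcFaceLayer3Penalties

layer = ArcFaceLayer3Penalties(n_classes=5, s=2, m1=1.0, m2=0.3, m3=0.2)
X = np.random.rand(4, 8).astype("float32")
y = np.random.randint(5, size=4)
logits = layer(X, y).numpy()

W = layer.W.numpy()
cos = (X / np.linalg.norm(X, axis=1, keepdims=True)) @ (
    W / np.linalg.norm(W, axis=0, keepdims=True)
)
theta = np.arccos(np.clip(cos, -1.0, 1.0))
expected = 2.0 * cos
rows = np.arange(4)
expected[rows, y] = 2.0 * (np.cos(1.0 * theta[rows, y] + 0.3) - 0.2)
assert np.allclose(logits, expected, atol=1e-4)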
import tensorflow as tf
from bob.learn.tensorflow.metrics.embedding_accuracy import accuracy_from_embeddings
class EmbeddingValidation(tf.keras.Model):
    """
    Use this model if the validation step should compute accuracy with respect to the embeddings.

    In this model, `test_step` runs `bob.learn.tensorflow.metrics.embedding_accuracy.accuracy_from_embeddings`.
    """

    def compile(
        self, **kwargs,
    ):
        """
        Compile
        """
        super().compile(**kwargs)
        # mean of the training loss; named distinctly from the validation accuracy
        self.train_loss = tf.keras.metrics.Mean(name="train_loss")
        self.validation_acc = tf.keras.metrics.Mean(name="accuracy")

    def train_step(self, data):
        """
        Train Step
        """

        X, y = data

        with tf.GradientTape() as tape:
            logits, _ = self(X, training=True)
            loss = self.loss(y, logits)

        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
        self.compiled_metrics.update_state(y, logits, sample_weight=None)
        self.train_loss(loss)
        return {m.name: m.result() for m in self.metrics + [self.train_loss]}

    def test_step(self, data):
        """
        Test Step
        """

        images, labels = data
        logits, prelogits = self(images, training=False)
        self.validation_acc(accuracy_from_embeddings(labels, prelogits))
        return {m.name: m.result() for m in [self.validation_acc]}
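

# A minimal usage sketch (an assumed toy architecture): an EmbeddingValidation
# model must output (logits, prelogits); validation accuracy is then computed
# from the prelogits/embeddings rather than from the logits.
import tensorflow as tf
from bob.learn.tensorflow.models import EmbeddingValidation

inputs = tf.keras.Input(shape=(32,))
embeddings = tf.keras.layers.Dense(16, name="embeddings")(inputs)  # stand-in backbone
logits = tf.keras.layers.Dense(10, name="logits")(embeddings)

model = EmbeddingValidation(inputs=inputs, outputs=(logits, embeddings))
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)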
from bob.learn.tensorflow.models import (
    EmbeddingValidation,
    ArcFaceLayer,
    ArcFaceModel,
    ArcFaceLayer3Penalties,
)
from bob.learn.tensorflow.layers import (
    SphereFaceLayer,
    ModifiedSoftMaxLayer,
)
import numpy as np


def test_arcface_layer():
    layer = ArcFaceLayer()
    np.random.seed(10)
    X = np.random.rand(10, 50)
    y = [np.random.randint(10) for i in range(10)]
    assert layer(X, y).shape == (10, 10)


def test_arcface_layer_3p():
    layer = ArcFaceLayer3Penalties()
    np.random.seed(10)
    X = np.random.rand(10, 50)
    y = [np.random.randint(10) for i in range(10)]
    assert layer(X, y).shape == (10, 10)


def test_sphereface():
    layer = SphereFaceLayer()
    np.random.seed(10)
    X = np.random.rand(10, 10)
    assert layer(X).shape == (10, 10)


def test_modsoftmax():
    layer = ModifiedSoftMaxLayer()
    np.random.seed(10)
    X = np.random.rand(10, 10)
    assert layer(X).shape == (10, 10)