Skip to content
Snippets Groups Projects
Commit 4523531b authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

add center loss, mmd loss, and pairwise confusion loss

parent ea28b602
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,9 @@ from .TripletLoss import triplet_loss, triplet_average_loss, triplet_fisher_loss
from .StyleLoss import linear_gram_style_loss, content_loss, denoising_loss
from .vat import VATLoss
from .pixel_wise import PixelWise
from .center_loss import CenterLoss
from .mmd import *
from .pairwise_confusion import total_pairwise_confusion
from .utils import *
......
import tensorflow as tf
# TODO(amir): replace parent class with tf.Module in tensorflow 1.14 and above.
# * pass ``name`` to parent class
# * replace get_variable with tf.Variable
# * replace variable_scope with name_scope
class CenterLoss:
"""Center loss."""
def __init__(self, n_classes, n_features, alpha=0.9, name="center_loss", **kwargs):
super().__init__(**kwargs)
self.n_classes = n_classes
self.n_features = n_features
self.alpha = alpha
self.name = name
with tf.variable_scope(self.name):
self.centers = tf.get_variable(
"centers",
[n_classes, n_features],
dtype=tf.float32,
initializer=tf.constant_initializer(0.),
trainable=False,
)
def __call__(self, sparse_labels, prelogits):
with tf.name_scope(self.name):
centers_batch = tf.gather(self.centers, sparse_labels)
diff = (1 - self.alpha) * (centers_batch - prelogits)
self.centers_update_op = tf.scatter_sub(self.centers, sparse_labels, diff)
center_loss = tf.reduce_mean(tf.square(prelogits - centers_batch))
tf.summary.scalar("loss_center", center_loss)
# Add histogram for all centers
for i in range(self.n_classes):
tf.summary.histogram(f"center_{i}", self.centers[i])
return center_loss
@property
def update_ops(self):
return [self.centers_update_op]
import tensorflow as tf
def compute_kernel(x, y):
x_size = tf.shape(x)[0]
y_size = tf.shape(y)[0]
dim = tf.shape(x)[1]
tiled_x = tf.tile(
tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
)
tiled_y = tf.tile(
tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
)
return tf.exp(
-tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
)
def mmd(x, y):
x_kernel = compute_kernel(x, x)
y_kernel = compute_kernel(y, y)
xy_kernel = compute_kernel(x, y)
return (
tf.reduce_mean(x_kernel)
+ tf.reduce_mean(y_kernel)
- 2 * tf.reduce_mean(xy_kernel)
)
import tensorflow as tf
from ..utils import pdist_safe, upper_triangle
def total_pairwise_confusion(prelogits, name=None):
"""Total Pairwise Confusion Loss
[1]X. Tu et al., “Learning Generalizable and Identity-Discriminative
Representations for Face Anti-Spoofing,” arXiv preprint arXiv:1901.05602, 2019.
"""
# compute L2 norm between all prelogits and sum them.
with tf.name_scope(name, default_name="total_pairwise_confusion"):
prelogits = tf.reshape(prelogits, (tf.shape(prelogits)[0], -1))
loss_tpc = tf.reduce_mean(upper_triangle(pdist_safe(prelogits)))
tf.summary.scalar("loss_tpc", loss_tpc)
return loss_tpc
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment