Commit fa765388 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

add center loss, mmd loss, and pairwise confusion loss

parent e155ed50
......@@ -4,6 +4,9 @@ from .TripletLoss import triplet_loss, triplet_average_loss, triplet_fisher_loss
from .StyleLoss import linear_gram_style_loss, content_loss, denoising_loss
from .vat import VATLoss
from .pixel_wise import PixelWise
from .center_loss import CenterLoss
from .mmd import *
from .pairwise_confusion import total_pairwise_confusion
from .utils import *
import tensorflow as tf
# TODO(amir): replace parent class with tf.Module in tensorflow 1.14 and above.
# * pass ``name`` to parent class
# * replace get_variable with tf.Variable
# * replace variable_scope with name_scope
class CenterLoss:
"""Center loss."""
def __init__(self, n_classes, n_features, alpha=0.9, name="center_loss", **kwargs):
self.n_classes = n_classes
self.n_features = n_features
self.alpha = alpha = name
with tf.variable_scope(
self.centers = tf.get_variable(
[n_classes, n_features],
def __call__(self, sparse_labels, prelogits):
with tf.name_scope(
centers_batch = tf.gather(self.centers, sparse_labels)
diff = (1 - self.alpha) * (centers_batch - prelogits)
self.centers_update_op = tf.scatter_sub(self.centers, sparse_labels, diff)
center_loss = tf.reduce_mean(tf.square(prelogits - centers_batch))
tf.summary.scalar("loss_center", center_loss)
# Add histogram for all centers
for i in range(self.n_classes):
tf.summary.histogram(f"center_{i}", self.centers[i])
return center_loss
def update_ops(self):
return [self.centers_update_op]
import tensorflow as tf
def compute_kernel(x, y):
x_size = tf.shape(x)[0]
y_size = tf.shape(y)[0]
dim = tf.shape(x)[1]
tiled_x = tf.tile(
tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
tiled_y = tf.tile(
tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
return tf.exp(
-tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
def mmd(x, y):
x_kernel = compute_kernel(x, x)
y_kernel = compute_kernel(y, y)
xy_kernel = compute_kernel(x, y)
return (
+ tf.reduce_mean(y_kernel)
- 2 * tf.reduce_mean(xy_kernel)
import tensorflow as tf
from ..utils import pdist_safe, upper_triangle
def total_pairwise_confusion(prelogits, name=None):
"""Total Pairwise Confusion Loss
[1]X. Tu et al., “Learning Generalizable and Identity-Discriminative
Representations for Face Anti-Spoofing,” arXiv preprint arXiv:1901.05602, 2019.
# compute L2 norm between all prelogits and sum them.
with tf.name_scope(name, default_name="total_pairwise_confusion"):
prelogits = tf.reshape(prelogits, (tf.shape(prelogits)[0], -1))
loss_tpc = tf.reduce_mean(upper_triangle(pdist_safe(prelogits)))
tf.summary.scalar("loss_tpc", loss_tpc)
return loss_tpc
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment