From fa765388f6637c421168f1664291a314adc65225 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Fri, 7 Feb 2020 16:04:18 +0100 Subject: [PATCH] add center loss, mmd loss, and pairwise confusion loss --- bob/learn/tensorflow/loss/__init__.py | 3 ++ bob/learn/tensorflow/loss/center_loss.py | 39 +++++++++++++++++++ bob/learn/tensorflow/loss/mmd.py | 27 +++++++++++++ .../tensorflow/loss/pairwise_confusion.py | 16 ++++++++ 4 files changed, 85 insertions(+) create mode 100644 bob/learn/tensorflow/loss/center_loss.py create mode 100644 bob/learn/tensorflow/loss/mmd.py create mode 100644 bob/learn/tensorflow/loss/pairwise_confusion.py diff --git a/bob/learn/tensorflow/loss/__init__.py b/bob/learn/tensorflow/loss/__init__.py index eab22bb4..7d3937ff 100644 --- a/bob/learn/tensorflow/loss/__init__.py +++ b/bob/learn/tensorflow/loss/__init__.py @@ -4,6 +4,9 @@ from .TripletLoss import triplet_loss, triplet_average_loss, triplet_fisher_loss from .StyleLoss import linear_gram_style_loss, content_loss, denoising_loss from .vat import VATLoss from .pixel_wise import PixelWise +from .center_loss import CenterLoss +from .mmd import * +from .pairwise_confusion import total_pairwise_confusion from .utils import * diff --git a/bob/learn/tensorflow/loss/center_loss.py b/bob/learn/tensorflow/loss/center_loss.py new file mode 100644 index 00000000..00494387 --- /dev/null +++ b/bob/learn/tensorflow/loss/center_loss.py @@ -0,0 +1,39 @@ +import tensorflow as tf + +# TODO(amir): replace parent class with tf.Module in tensorflow 1.14 and above. +# * pass ``name`` to parent class +# * replace get_variable with tf.Variable +# * replace variable_scope with name_scope +class CenterLoss: + """Center loss.""" + + def __init__(self, n_classes, n_features, alpha=0.9, name="center_loss", **kwargs): + super().__init__(**kwargs) + self.n_classes = n_classes + self.n_features = n_features + self.alpha = alpha + self.name = name + with tf.variable_scope(self.name): + self.centers = tf.get_variable( + "centers", + [n_classes, n_features], + dtype=tf.float32, + initializer=tf.constant_initializer(0.), + trainable=False, + ) + + def __call__(self, sparse_labels, prelogits): + with tf.name_scope(self.name): + centers_batch = tf.gather(self.centers, sparse_labels) + diff = (1 - self.alpha) * (centers_batch - prelogits) + self.centers_update_op = tf.scatter_sub(self.centers, sparse_labels, diff) + center_loss = tf.reduce_mean(tf.square(prelogits - centers_batch)) + tf.summary.scalar("loss_center", center_loss) + # Add histogram for all centers + for i in range(self.n_classes): + tf.summary.histogram(f"center_{i}", self.centers[i]) + return center_loss + + @property + def update_ops(self): + return [self.centers_update_op] diff --git a/bob/learn/tensorflow/loss/mmd.py b/bob/learn/tensorflow/loss/mmd.py new file mode 100644 index 00000000..bd7df3e5 --- /dev/null +++ b/bob/learn/tensorflow/loss/mmd.py @@ -0,0 +1,27 @@ +import tensorflow as tf + + +def compute_kernel(x, y): + x_size = tf.shape(x)[0] + y_size = tf.shape(y)[0] + dim = tf.shape(x)[1] + tiled_x = tf.tile( + tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1]) + ) + tiled_y = tf.tile( + tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1]) + ) + return tf.exp( + -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32) + ) + + +def mmd(x, y): + x_kernel = compute_kernel(x, x) + y_kernel = compute_kernel(y, y) + xy_kernel = compute_kernel(x, y) + return ( + tf.reduce_mean(x_kernel) + + tf.reduce_mean(y_kernel) + - 2 * tf.reduce_mean(xy_kernel) + ) diff --git a/bob/learn/tensorflow/loss/pairwise_confusion.py b/bob/learn/tensorflow/loss/pairwise_confusion.py new file mode 100644 index 00000000..155b1a29 --- /dev/null +++ b/bob/learn/tensorflow/loss/pairwise_confusion.py @@ -0,0 +1,16 @@ +import tensorflow as tf +from ..utils import pdist_safe, upper_triangle + +def total_pairwise_confusion(prelogits, name=None): + """Total Pairwise Confusion Loss + + [1]X. Tu et al., “Learning Generalizable and Identity-Discriminative + Representations for Face Anti-Spoofing,” arXiv preprint arXiv:1901.05602, 2019. + """ + # compute L2 norm between all prelogits and sum them. + with tf.name_scope(name, default_name="total_pairwise_confusion"): + prelogits = tf.reshape(prelogits, (tf.shape(prelogits)[0], -1)) + loss_tpc = tf.reduce_mean(upper_triangle(pdist_safe(prelogits))) + + tf.summary.scalar("loss_tpc", loss_tpc) + return loss_tpc -- GitLab