From fa765388f6637c421168f1664291a314adc65225 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Fri, 7 Feb 2020 16:04:18 +0100
Subject: [PATCH] add center loss, mmd loss, and pairwise confusion loss

---
 bob/learn/tensorflow/loss/__init__.py         |  3 ++
 bob/learn/tensorflow/loss/center_loss.py      | 39 +++++++++++++++++++
 bob/learn/tensorflow/loss/mmd.py              | 27 +++++++++++++
 .../tensorflow/loss/pairwise_confusion.py     | 16 ++++++++
 4 files changed, 85 insertions(+)
 create mode 100644 bob/learn/tensorflow/loss/center_loss.py
 create mode 100644 bob/learn/tensorflow/loss/mmd.py
 create mode 100644 bob/learn/tensorflow/loss/pairwise_confusion.py

diff --git a/bob/learn/tensorflow/loss/__init__.py b/bob/learn/tensorflow/loss/__init__.py
index eab22bb4..7d3937ff 100644
--- a/bob/learn/tensorflow/loss/__init__.py
+++ b/bob/learn/tensorflow/loss/__init__.py
@@ -4,6 +4,9 @@ from .TripletLoss import triplet_loss, triplet_average_loss, triplet_fisher_loss
 from .StyleLoss import linear_gram_style_loss, content_loss, denoising_loss
 from .vat import VATLoss
 from .pixel_wise import PixelWise
+from .center_loss import CenterLoss
+from .mmd import *
+from .pairwise_confusion import total_pairwise_confusion
 from .utils import *
 
 
diff --git a/bob/learn/tensorflow/loss/center_loss.py b/bob/learn/tensorflow/loss/center_loss.py
new file mode 100644
index 00000000..00494387
--- /dev/null
+++ b/bob/learn/tensorflow/loss/center_loss.py
@@ -0,0 +1,39 @@
+import tensorflow as tf
+
+# TODO(amir): replace parent class with tf.Module in tensorflow 1.14 and above.
+# * pass ``name`` to parent class
+# * replace get_variable with tf.Variable
+# * replace variable_scope with name_scope
+class CenterLoss:
+    """Center loss."""
+
+    def __init__(self, n_classes, n_features, alpha=0.9, name="center_loss", **kwargs):
+        super().__init__(**kwargs)
+        self.n_classes = n_classes
+        self.n_features = n_features
+        self.alpha = alpha
+        self.name = name
+        with tf.variable_scope(self.name):
+            self.centers = tf.get_variable(
+                "centers",
+                [n_classes, n_features],
+                dtype=tf.float32,
+                initializer=tf.constant_initializer(0.),
+                trainable=False,
+            )
+
+    def __call__(self, sparse_labels, prelogits):
+        with tf.name_scope(self.name):
+            centers_batch = tf.gather(self.centers, sparse_labels)
+            diff = (1 - self.alpha) * (centers_batch - prelogits)
+            self.centers_update_op = tf.scatter_sub(self.centers, sparse_labels, diff)
+            center_loss = tf.reduce_mean(tf.square(prelogits - centers_batch))
+        tf.summary.scalar("loss_center", center_loss)
+        # Add histogram for all centers
+        for i in range(self.n_classes):
+            tf.summary.histogram(f"center_{i}", self.centers[i])
+        return center_loss
+
+    @property
+    def update_ops(self):
+        return [self.centers_update_op]
diff --git a/bob/learn/tensorflow/loss/mmd.py b/bob/learn/tensorflow/loss/mmd.py
new file mode 100644
index 00000000..bd7df3e5
--- /dev/null
+++ b/bob/learn/tensorflow/loss/mmd.py
@@ -0,0 +1,27 @@
+import tensorflow as tf
+
+
+def compute_kernel(x, y):
+    x_size = tf.shape(x)[0]
+    y_size = tf.shape(y)[0]
+    dim = tf.shape(x)[1]
+    tiled_x = tf.tile(
+        tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1])
+    )
+    tiled_y = tf.tile(
+        tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1])
+    )
+    return tf.exp(
+        -tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32)
+    )
+
+
+def mmd(x, y):
+    x_kernel = compute_kernel(x, x)
+    y_kernel = compute_kernel(y, y)
+    xy_kernel = compute_kernel(x, y)
+    return (
+        tf.reduce_mean(x_kernel)
+        + tf.reduce_mean(y_kernel)
+        - 2 * tf.reduce_mean(xy_kernel)
+    )
diff --git a/bob/learn/tensorflow/loss/pairwise_confusion.py b/bob/learn/tensorflow/loss/pairwise_confusion.py
new file mode 100644
index 00000000..155b1a29
--- /dev/null
+++ b/bob/learn/tensorflow/loss/pairwise_confusion.py
@@ -0,0 +1,16 @@
+import tensorflow as tf
+from ..utils import pdist_safe, upper_triangle
+
+def total_pairwise_confusion(prelogits, name=None):
+    """Total Pairwise Confusion Loss
+
+        [1]X. Tu et al., “Learning Generalizable and Identity-Discriminative
+        Representations for Face Anti-Spoofing,” arXiv preprint arXiv:1901.05602, 2019.
+    """
+    # compute L2 norm between all prelogits and sum them.
+    with tf.name_scope(name, default_name="total_pairwise_confusion"):
+        prelogits = tf.reshape(prelogits, (tf.shape(prelogits)[0], -1))
+        loss_tpc = tf.reduce_mean(upper_triangle(pdist_safe(prelogits)))
+
+    tf.summary.scalar("loss_tpc", loss_tpc)
+    return loss_tpc
-- 
GitLab