# pairwise_confusion.py
import tensorflow as tf
from ..utils import pdist_safe, upper_triangle

def total_pairwise_confusion(prelogits, name=None):
    """Compute the Total Pairwise Confusion (TPC) loss for a batch.

    Each sample in ``prelogits`` is flattened to a single vector, the
    flattened batch is passed through ``pdist_safe`` and ``upper_triangle``,
    and the mean of the resulting values is returned. The loss is also
    registered as a scalar summary named ``"loss_tpc"``.

    Reference:
        [1] X. Tu et al., "Learning Generalizable and Identity-Discriminative
        Representations for Face Anti-Spoofing," arXiv preprint
        arXiv:1901.05602, 2019.

    Args:
        prelogits: Tensor whose first dimension indexes the samples of the
            batch; all remaining dimensions are collapsed into one.
        name: Optional name for the enclosing name scope. When ``None`` the
            scope defaults to ``"total_pairwise_confusion"``.

    Returns:
        The scalar tensor holding the mean of the selected pairwise values.
    """
    with tf.name_scope(name, default_name="total_pairwise_confusion"):
        # Collapse every non-batch dimension so each sample is one row.
        batch_size = tf.shape(prelogits)[0]
        flattened = tf.reshape(prelogits, (batch_size, -1))

        # NOTE(review): pdist_safe / upper_triangle live in ..utils and are
        # not visible here; presumably they yield the pairwise L2 distance
        # matrix and its upper-triangular entries (one entry per pair) --
        # confirm against the helpers.
        distances = pdist_safe(flattened)
        unique_pairs = upper_triangle(distances)

        # The pairwise values are averaged (not summed) into a single scalar.
        loss_tpc = tf.reduce_mean(unique_pairs)

    # The summary is recorded outside the name scope above.
    tf.summary.scalar("loss_tpc", loss_tpc)
    return loss_tpc