mmd.py 860 Bytes
Newer Older
1 2 3 4
import tensorflow as tf


def compute_kernel(x, y):
    """Gaussian kernel matrix between the rows of ``x`` and ``y``.

    Entry ``[i, j]`` of the result is::

        exp(-mean_d((x[i, d] - y[j, d]) ** 2) / D)

    where ``D`` is the number of columns. Because the mean already divides by
    ``D``, this equals ``exp(-||x_i - y_j||^2 / D**2)``.

    Args:
        x: 2-D tensor-like of shape ``(N, D)``.
        y: 2-D tensor-like of shape ``(M, D)``, same dtype as ``x``.

    Returns:
        Tensor of shape ``(N, M)`` with the dtype of ``x``.
    """
    x = tf.convert_to_tensor(x)
    y = tf.convert_to_tensor(y)
    dim = tf.shape(x)[1]
    # Broadcast (N, 1, D) against (1, M, D) instead of materialising two tiled
    # copies of the inputs; the pairwise difference is still (N, M, D).
    diff = tf.expand_dims(x, 1) - tf.expand_dims(y, 0)
    mean_sq_diff = tf.reduce_mean(tf.square(diff), axis=2)
    # Cast the divisor to the input dtype so float64/float16 inputs work too
    # (a fixed float32 divisor raises a dtype mismatch for them).
    return tf.exp(-mean_sq_diff / tf.cast(dim, x.dtype))


def mmd(x, y):
    """Maximum Mean Discrepancy with Gaussian kernel.

    Returns the biased (V-statistic) estimate of the *squared* MMD between the
    samples ``x`` and ``y``::

        mean(k(x, x)) + mean(k(y, y)) - 2 * mean(k(x, y))

    where ``k`` is ``compute_kernel``. The means run over the full kernel
    matrices, including the diagonal (self-similarity) entries of ``k(x, x)``
    and ``k(y, y)``. That is what makes the estimator biased.

    See: https://stats.stackexchange.com/a/276618/49433

    Args:
        x: 2-D tensor-like of shape ``(N, D)``.
        y: 2-D tensor-like of shape ``(M, D)``, same dtype as ``x``.

    Returns:
        Scalar tensor.
    """
    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    return (
        tf.reduce_mean(x_kernel)
        + tf.reduce_mean(y_kernel)
        - 2 * tf.reduce_mean(xy_kernel)
    )