"""
Implements the MINE loss from the paper:

    Belghazi et al., "Mutual Information Neural Estimation", 2018
    (https://arxiv.org/pdf/1801.04062.pdf)

"""

import tensorflow as tf

class MineModel(tf.keras.Model):
    """Statistics network and loss for Mutual Information Neural Estimation.

    Given a batch of jointly drawn pairs ``(x, z)``, the model scores each pair
    with a small network ``T(x, z)``. It also scores pairs in which ``z`` has
    been shuffled across the batch; these stand in for samples from the
    product of the marginals. From the two sets of scores it computes a lower
    bound on the mutual information between ``x`` and ``z``:

    * default (equation 5):
      ``mean(T_joint) - log(mean(exp(T_shuffled)))``
    * MINE-F (equation 6):
      ``mean(T_joint) - mean(exp(T_shuffled - 1))``

    The negated bound is registered with ``add_loss``, so minimizing the model
    loss maximizes the bound. ``call`` returns the bound itself.

    Parameters
    ----------
    is_mine_f : bool
        If true, implement MINE-F (equation 6); otherwise implement equation 5.
    name : str
        Name of the Keras model.
    units : int
        Width of the hidden ``Dense`` layers of the statistics network.
    **kwargs
        Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, is_mine_f=False, name="MINE", units=10, **kwargs):
        super().__init__(name=name, **kwargs)
        self.units = units
        self.is_mine_f = is_mine_f

        # Statistics network:
        # T(x, z) = Dense(1)(relu(Dense(relu(Dense(x) + Dense(z)))))
        self.transformer_x = tf.keras.layers.Dense(self.units)
        self.transformer_z = tf.keras.layers.Dense(self.units)
        self.transformer_xz = tf.keras.layers.Dense(self.units)
        self.transformer_output = tf.keras.layers.Dense(1)

    def _score_pairs(self, x, z):
        """Evaluate the statistics network T on row-wise paired ``x`` and ``z``."""
        # tf.nn.relu is stateless; there is no need to build a new ReLU layer
        # object on every forward pass.
        h1 = tf.nn.relu(self.transformer_x(x) + self.transformer_z(z))
        return self.transformer_output(tf.nn.relu(self.transformer_xz(h1)))

    def call(self, inputs):
        """Compute the MINE lower bound for a batch of ``(x, z)`` pairs.

        Parameters
        ----------
        inputs : sequence
            ``inputs[0]`` is ``x`` and ``inputs[1]`` is ``z``. Rows of ``x`` and
            ``z`` are paired by their position along the first axis (assumed
            to be the batch axis, with equal sizes -- TODO confirm with callers).

        Returns
        -------
        tf.Tensor
            The lower-bound estimate. Its negation is also added to the
            model's losses via ``self.add_loss``.
        """
        x = inputs[0]
        z = inputs[1]

        # Scores on jointly drawn pairs.
        t_xz = self._score_pairs(x, z)

        # Permuting z across the batch breaks its pairing with x, so these
        # scores approximate samples from the product of marginals. Gathering
        # with a shuffled index permutation (instead of tf.random.shuffle(z))
        # keeps gradients flowing back into z.
        permutation = tf.random.shuffle(tf.range(tf.shape(z)[0]))
        z_shuffle = tf.gather(z, permutation)
        t_x_z = self._score_pairs(x, z_shuffle)

        # axis=0 averages over the batch but keeps the trailing output
        # dimension, preserving the shape of the returned tensor.
        joint_term = tf.reduce_mean(t_xz, axis=0)

        if self.is_mine_f:
            # MINE-F (equation 6).
            marginal_term = tf.reduce_mean(tf.math.exp(t_x_z - 1))
        else:
            # Equation 5: log(mean(exp(t))), evaluated as
            # logsumexp(t) - log(n) so that large scores do not overflow exp.
            n = tf.cast(tf.size(t_x_z), t_x_z.dtype)
            marginal_term = tf.reduce_logsumexp(t_x_z) - tf.math.log(n)

        lower_bound = joint_term - marginal_term

        # Minimizing the registered loss maximizes the bound.
        self.add_loss(-lower_bound)
        return lower_bound

    def get_config(self):
        """Return the constructor arguments needed to rebuild this model."""
        config = super().get_config()
        # is_mine_f selects the bound; without it a model restored from its
        # config would silently fall back to equation 5.
        config.update({"units": self.units, "is_mine_f": self.is_mine_f})
        return config