From 3ea8cb227f80b2365a93599af2470e58f25beec8 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Tue, 20 Oct 2020 18:13:13 +0200 Subject: [PATCH] Tests with independence Implemented MINE Implemented MINE --- bob/learn/tensorflow/losses/__init__.py | 2 +- bob/learn/tensorflow/metrics/__init__.py | 2 +- bob/learn/tensorflow/models/__init__.py | 2 + bob/learn/tensorflow/models/mine.py | 68 ++++++++++++++++++++++++ bob/learn/tensorflow/tests/test_mine.py | 33 ++++++++++++ requirements.txt | 13 +---- 6 files changed, 106 insertions(+), 14 deletions(-) create mode 100644 bob/learn/tensorflow/models/mine.py create mode 100644 bob/learn/tensorflow/tests/test_mine.py diff --git a/bob/learn/tensorflow/losses/__init__.py b/bob/learn/tensorflow/losses/__init__.py index 2bfcbd15..65cdcab2 100644 --- a/bob/learn/tensorflow/losses/__init__.py +++ b/bob/learn/tensorflow/losses/__init__.py @@ -18,6 +18,6 @@ def __appropriate__(*args): __appropriate__( CenterLoss, - CenterLossLayer, + CenterLossLayer ) __all__ = [_ for _ in dir() if not _.startswith("_")] diff --git a/bob/learn/tensorflow/metrics/__init__.py b/bob/learn/tensorflow/metrics/__init__.py index 55cec2bd..72ee7b5f 100644 --- a/bob/learn/tensorflow/metrics/__init__.py +++ b/bob/learn/tensorflow/metrics/__init__.py @@ -1,4 +1,4 @@ -from .embedding_accuracy import EmbeddingAccuracy +from .embedding_accuracy import EmbeddingAccuracy, predict_using_tensors # gets sphinx autodoc done right - don't remove it def __appropriate__(*args): diff --git a/bob/learn/tensorflow/models/__init__.py b/bob/learn/tensorflow/models/__init__.py index d18ceb93..48804ec2 100644 --- a/bob/learn/tensorflow/models/__init__.py +++ b/bob/learn/tensorflow/models/__init__.py @@ -1,5 +1,6 @@ from .alexnet import AlexNet_simplified from .densenet import DenseNet +from .mine import MineModel # gets sphinx autodoc done right - don't remove it def __appropriate__(*args): @@ -20,5 +21,6 @@ def __appropriate__(*args): __appropriate__( AlexNet_simplified, DenseNet, + MineModel ) __all__ = [_ for _ in dir() if not _.startswith("_")] diff --git a/bob/learn/tensorflow/models/mine.py b/bob/learn/tensorflow/models/mine.py new file mode 100644 index 00000000..4f766236 --- /dev/null +++ b/bob/learn/tensorflow/models/mine.py @@ -0,0 +1,68 @@ +""" +Implements the MINE loss from the paper: + +Mutual Information Neural Estimation (https://arxiv.org/pdf/1801.04062.pdf) + +""" + +import tensorflow as tf + +class MineModel(tf.keras.Model): + """ + + Parameters + ********** + + is_mine_f: bool + If true, will implement MINE-F (equation 6), otherwise will implement equation 5 + """ + + def __init__(self, is_mine_f=False, name="MINE", units=10, **kwargs): + super().__init__(name=name, **kwargs) + self.units = units + self.is_mine_f = is_mine_f + + self.transformer_x = tf.keras.layers.Dense(self.units) + self.transformer_z = tf.keras.layers.Dense(self.units) + self.transformer_xz = tf.keras.layers.Dense(self.units) + self.transformer_output = tf.keras.layers.Dense(1) + + def call(self, inputs): + def compute(x, z): + h1_x = self.transformer_x(x) + h1_z = self.transformer_z(z) + h1 = tf.keras.layers.ReLU()(h1_x + h1_z) + h2 = self.transformer_output(tf.keras.layers.ReLU()(self.transformer_xz(h1))) + + return h2 + + def compute_lower_bound(x, z): + t_xz = compute(x,z) + z_shuffle = tf.random.shuffle(z) + t_x_z = compute(x, z_shuffle) + + if self.is_mine_f: + lb = -( + tf.reduce_mean(t_xz, axis=0) + - tf.reduce_mean(tf.math.exp(t_x_z-1)) + ) + else: + lb = -( + tf.reduce_mean(t_xz, axis=0) + - tf.math.log(tf.reduce_mean(tf.math.exp(t_x_z))) + ) + + self.add_loss(lb) + return -lb + + x = inputs[0] + z = inputs[1] + + return compute_lower_bound(x, z) + + + def get_config(self): + config = super().get_config() + config.update({"units": self.units}) + return config + diff --git a/bob/learn/tensorflow/tests/test_mine.py b/bob/learn/tensorflow/tests/test_mine.py new file mode 100644 index 00000000..89cf943b --- /dev/null +++ b/bob/learn/tensorflow/tests/test_mine.py @@ -0,0 +1,33 @@ +import numpy as np +import tensorflow as tf +from bob.learn.tensorflow.models import MineModel + +def run_mine(is_mine_f): + np.random.seed(10) + N = 20000 + d = 1 + EPOCHS = 100 + + X = np.sign(np.random.normal(0.,1.,[N, d])) + Z = X + np.random.normal(0.,np.sqrt(0.2),[N, d]) + + + from sklearn.feature_selection import mutual_info_regression + mi_numerical = mutual_info_regression(X.reshape(-1, 1), Z.ravel())[0] + + model = MineModel(is_mine_f=is_mine_f) + model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01)) + + callback = model.fit( + x=[X, Z], epochs=EPOCHS, verbose=1, batch_size=100 + ) + mine = -np.array(callback.history["loss"])[-1] + + np.allclose(mine,mi_numerical, atol=0.01) + + +def test_mine(): + run_mine(False) + +def test_mine_f(): + run_mine(True) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 317ae0d9..a6cac773 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,16 +2,5 @@ setuptools bob.extension bob.io.base bob.io.image -bob.learn.activation -bob.learn.em -bob.learn.linear bob.ip.base -bob.math -bob.measure -bob.sp -bob.db.mnist -bob.db.atnt -gridtk -numpy -scipy -click >= 7 +bob.measure \ No newline at end of file -- GitLab