From 22a73ed12068679b600ce6c5b65804dd8a8f6395 Mon Sep 17 00:00:00 2001
From: Yannick DAYER <yannick.dayer@idiap.ch>
Date: Fri, 26 Nov 2021 22:51:57 +0100
Subject: [PATCH] Add the test_em tests to test_gmm

---
 bob/learn/em/mixture/gmm.py   |   8 +-
 bob/learn/em/test/test_gmm.py | 258 ++++++++++++++++++++++++++++++++++
 2 files changed, 263 insertions(+), 3 deletions(-)

diff --git a/bob/learn/em/mixture/gmm.py b/bob/learn/em/mixture/gmm.py
index e3f72ff..d64b6f1 100644
--- a/bob/learn/em/mixture/gmm.py
+++ b/bob/learn/em/mixture/gmm.py
@@ -480,9 +480,9 @@ class GMMMachine(BaseEstimator):
                 ubm=ubm,
                 weights=hdf5["m_weights"][()],
             )
-            self.means = g_means
-            self.variances = g_variances
-            self.variance_thresholds = g_variance_thresholds
+            self.means = np.array(g_means)
+            self.variances = np.array(g_variances)
+            self.variance_thresholds = np.array(g_variance_thresholds)
         return self
 
     def save(self, hdf5):
@@ -717,11 +717,13 @@ class GMMMachine(BaseEstimator):
             # Note: Uses the stats from before m_step, leading to an additional m_step
             # (which is not bad because it will always converge)
             average_output = stats.log_likelihood / stats.t
+            logger.debug(f"average output = {average_output}")
 
             if step > 1:
                 convergence_value = abs(
                     (average_output_previous - average_output) / average_output_previous
                 )
+                logger.debug(f"convergence val = {convergence_value}")
 
                 # Terminates if converged (and likelihood computation is set)
                 if (
diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py
index 3f8dc37..40d535a 100644
--- a/bob/learn/em/test/test_gmm.py
+++ b/bob/learn/em/test/test_gmm.py
@@ -542,3 +542,261 @@ def test_map_transformer():
         sum_pxx=np.array([[2, 2, 5], [128, 145, 145]], dtype=float),
     )
     assert stats.is_similar_to(expected_stats)
+
+
+## Tests from `test_em.py`
+
+def loadGMM():
+    gmm = GMMMachine(2)
+
+    gmm.weights = bob.io.base.load(datafile("gmm.init_weights.hdf5", __name__, path="../data/"))
+    gmm.means = bob.io.base.load(datafile("gmm.init_means.hdf5", __name__, path="../data/"))
+    gmm.variances = bob.io.base.load(datafile("gmm.init_variances.hdf5", __name__, path="../data/"))
+
+    return gmm
+
+def equals(x, y, epsilon):
+    return (abs(x - y) < epsilon).all()
+
+def test_gmm_ML_1():
+    """Trains a GMMMachine with ML_GMMTrainer"""
+    ar = bob.io.base.load(datafile("faithful.torch3_f64.hdf5", __name__, path="../data/"))
+    gmm = loadGMM()
+
+    # test rng handling
+    gmm.convergence_threshold = 0.001
+    gmm.update_means = True
+    gmm.update_variances = True
+    gmm.update_weights = True
+    gmm.random_state = np.random.RandomState(seed=12345)
+    gmm = gmm.fit(ar)
+
+    gmm = loadGMM()
+    gmm.convergence_threshold = 0.001
+    gmm.update_means = True
+    gmm.update_variances = True
+    gmm.update_weights = True
+    gmm = gmm.fit(ar)
+
+    #config = HDF5File(datafile("gmm_ML.hdf5", __name__), "w")
+    #gmm.save(config)
+
+    gmm_ref = GMMMachine.from_hdf5(HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data"), "r")) # TODO update the ref file(s)
+    gmm_ref_32bit_debug = GMMMachine.from_hdf5(HDF5File(datafile("gmm_ML_32bit_debug.hdf5", __name__, path="../data/"), "r"))
+    gmm_ref_32bit_release = GMMMachine.from_hdf5(HDF5File(datafile("gmm_ML_32bit_release.hdf5", __name__, path="../data/"), "r"))
+    assert (gmm == gmm_ref)  # or (gmm == gmm_ref_32bit_release) or (gmm == gmm_ref_32bit_debug)
+
+
+def test_gmm_ML_2():
+    """Trains a GMMMachine with ML_GMMTrainer; compares to an old reference"""
+    ar = bob.io.base.load(datafile("dataNormalized.hdf5", __name__, path="../data/"))
+
+    # Initialize GMMMachine
+    gmm = GMMMachine(5, 45)
+    gmm.means = bob.io.base.load(datafile("meansAfterKMeans.hdf5", __name__, path="../data/")).astype("float64")
+    gmm.variances = bob.io.base.load(datafile("variancesAfterKMeans.hdf5", __name__, path="../data/")).astype("float64")
+    gmm.weights = np.exp(bob.io.base.load(datafile("weightsAfterKMeans.hdf5", __name__, path="../data/")).astype("float64"))
+
+    threshold = 0.001
+    gmm.variance_thresholds = threshold
+
+    # Initialize ML Trainer
+    gmm.mean_var_update_threshold = 0.001
+    gmm.max_fitting_steps = 26
+    gmm.convergence_threshold = 0.00001
+    gmm.update_means = True
+    gmm.update_variances = True
+    gmm.update_weights = True
+
+    # Run ML
+    gmm = gmm.fit(ar)
+    # Test results
+    # Load torch3vision reference
+    meansML_ref = bob.io.base.load(datafile("meansAfterML.hdf5", __name__, path="../data/"))
+    variancesML_ref = bob.io.base.load(datafile("variancesAfterML.hdf5", __name__, path="../data/"))
+    weightsML_ref = bob.io.base.load(datafile("weightsAfterML.hdf5", __name__, path="../data/"))
+
+    # Compare to current results
+    np.testing.assert_allclose(gmm.means, meansML_ref, rtol=3e-3)
+    np.testing.assert_allclose(gmm.variances, variancesML_ref, rtol=3e-3)
+    np.testing.assert_allclose(gmm.weights, weightsML_ref, rtol=1e-4)
+
+
+def test_gmm_ML_parallel():
+    """Trains a GMMMachine with ML_GMMTrainer; compares to an old reference"""
+
+    ar = da.array(bob.io.base.load(datafile("dataNormalized.hdf5", __name__, path="../data/")))
+
+    # Initialize GMMMachine
+    gmm = GMMMachine(5, 45)
+    gmm.means = bob.io.base.load(datafile("meansAfterKMeans.hdf5", __name__, path="../data/")).astype("float64")
+    gmm.variances = bob.io.base.load(datafile("variancesAfterKMeans.hdf5", __name__, path="../data/")).astype("float64")
+    gmm.weights = np.exp(bob.io.base.load(datafile("weightsAfterKMeans.hdf5", __name__, path="../data/")).astype("float64"))
+
+    threshold = 0.001
+    gmm.variance_thresholds = threshold
+
+    # Initialize ML Trainer
+    gmm.mean_var_update_threshold = 0.001
+    gmm.max_fitting_steps = 25
+    gmm.convergence_threshold = 0.00001
+    gmm.update_means = True
+    gmm.update_variances = True
+    gmm.update_weights = True
+
+    # Run ML
+    gmm.fit(ar)
+
+    # Test results
+    # Load torch3vision reference
+    meansML_ref = bob.io.base.load(datafile("meansAfterML.hdf5", __name__, path="../data/"))
+    variancesML_ref = bob.io.base.load(datafile("variancesAfterML.hdf5", __name__, path="../data/"))
+    weightsML_ref = bob.io.base.load(datafile("weightsAfterML.hdf5", __name__, path="../data/"))
+
+    # Compare to current results
+    np.testing.assert_allclose(gmm.means, meansML_ref, rtol=3e-3)
+    np.testing.assert_allclose(gmm.variances, variancesML_ref, rtol=3e-3)
+    np.testing.assert_allclose(gmm.weights, weightsML_ref, rtol=1e-4)
+
+
+
+def test_gmm_MAP_1():
+    """Train a GMMMachine with MAP_GMMTrainer"""
+    ar = bob.io.base.load(datafile("faithful.torch3_f64.hdf5", __name__, path="../data/"))
+
+    # test with rng
+    gmmprior = GMMMachine.from_hdf5(HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/"), "r"))
+    gmm = GMMMachine.from_hdf5(HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/"), "r"), ubm = gmmprior)
+    gmm.update_means = True
+    gmm.update_variances = False
+    gmm.update_weights = False
+    rng = np.random.RandomState(seed=12345)
+    gmm.random_state = rng
+    gmm = gmm.fit(ar)
+
+    gmmprior = GMMMachine.from_hdf5(HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/"), "r"))
+    gmm = GMMMachine.from_hdf5(HDF5File(datafile("gmm_ML.hdf5", __name__, path="../data/"), "r"), ubm = gmmprior)
+    gmm.update_means = True
+    gmm.update_variances = False
+    gmm.update_weights = False
+
+    gmm = gmm.fit(ar)
+
+    gmm_ref = GMMMachine.from_hdf5(HDF5File(datafile("gmm_MAP.hdf5", __name__, path="../data/"), "r"))
+
+    np.testing.assert_allclose(gmm.means, gmm_ref.means, rtol=1e-3)
+    np.testing.assert_allclose(gmm.variances, gmm_ref.variances, rtol=1e-3)
+    np.testing.assert_allclose(gmm.weights, gmm_ref.weights, rtol=1e-3)
+
+
+def test_gmm_MAP_2():
+    """Train a GMMMachine with MAP_GMMTrainer and compare with matlab reference"""
+
+    data = bob.io.base.load(datafile("data.hdf5", __name__, path="../data/"))
+    data = data.reshape((-1, 1))  # make a 2D array out of it
+    means = bob.io.base.load(datafile("means.hdf5", __name__, path="../data/"))
+    variances = bob.io.base.load(datafile("variances.hdf5", __name__, path="../data/"))
+    weights = bob.io.base.load(datafile("weights.hdf5", __name__, path="../data/"))
+
+    gmm = GMMMachine(n_gaussians=2)
+    gmm.means = means
+    gmm.variances = variances
+    gmm.weights = weights
+
+    gmm_adapted = GMMMachine(
+        n_gaussians=2,
+        trainer="map",
+        ubm=gmm,
+        max_fitting_steps=1,
+        update_means=True,
+        update_variances=False,
+        update_weights=False,
+        mean_var_update_threshold=0.,
+    )
+    gmm_adapted.means = means
+    gmm_adapted.variances = variances
+    gmm_adapted.weights = weights
+
+    gmm = gmm.fit(data)
+
+    gmm_adapted = gmm_adapted.fit(data)
+
+    new_means = bob.io.base.load(datafile("new_adapted_mean.hdf5", __name__, path="../data/"))
+
+    # Compare to matlab reference
+    np.testing.assert_allclose(new_means[0,:], gmm_adapted.means[:,0], rtol=1e-4)
+    np.testing.assert_allclose(new_means[1,:], gmm_adapted.means[:,1], rtol=1e-4)
+
+
+def test_gmm_MAP_3():
+    """Train a GMMMachine with MAP_GMMTrainer; compares to old reference"""
+    ar = bob.io.base.load(datafile("dataforMAP.hdf5", __name__, path="../data/"))
+
+    # Initialize GMMMachine
+    n_gaussians = 5
+    prior_gmm = GMMMachine(n_gaussians)
+    prior_gmm.means = bob.io.base.load(datafile("meansAfterML.hdf5", __name__, path="../data/"))
+    prior_gmm.variances = bob.io.base.load(datafile("variancesAfterML.hdf5", __name__, path="../data/"))
+    prior_gmm.weights = bob.io.base.load(datafile("weightsAfterML.hdf5", __name__, path="../data/"))
+
+    threshold = 0.001
+    prior_gmm.variance_thresholds = threshold
+
+    # Initialize MAP Trainer
+    accuracy = 0.00001
+
+    gmm = GMMMachine(
+        n_gaussians,
+        trainer="map",
+        ubm=prior_gmm,
+        convergence_threshold=threshold,
+        max_fitting_steps=1,
+        update_means=True,
+        update_variances=False,
+        update_weights=False,
+        mean_var_update_threshold=accuracy
+    )
+    gmm.variance_thresholds = threshold
+
+    # Train
+    gmm = gmm.fit(ar)
+
+    # Test results
+    # Load torch3vision reference
+    meansMAP_ref = bob.io.base.load(datafile("meansAfterMAP.hdf5", __name__, path="../data/"))
+    variancesMAP_ref = bob.io.base.load(datafile("variancesAfterMAP.hdf5", __name__, path="../data/"))
+    weightsMAP_ref = bob.io.base.load(datafile("weightsAfterMAP.hdf5", __name__, path="../data/"))
+
+    # Compare to current results
+    # Gaps are quite large. This might be explained by the fact that there is no
+    # adaptation of a given Gaussian in torch3 when the corresponding responsibilities
+    # are below the responsibilities threshold
+    np.testing.assert_allclose(gmm.means, meansMAP_ref, rtol=2e-1)
+    np.testing.assert_allclose(gmm.variances, variancesMAP_ref, rtol=1e-4)
+    np.testing.assert_allclose(gmm.weights, weightsMAP_ref, rtol=1e-4)
+
+
+def test_gmm_test():
+
+    # Tests a GMMMachine by computing scores against a model and compare to
+    # an old reference
+
+    ar = bob.io.base.load(datafile("dataforMAP.hdf5", __name__, path="../data/"))
+
+    # Initialize GMMMachine
+    n_gaussians = 5
+    gmm = GMMMachine(n_gaussians)
+    gmm.means = bob.io.base.load(datafile("meansAfterML.hdf5", __name__, path="../data/"))
+    gmm.variances = bob.io.base.load(datafile("variancesAfterML.hdf5", __name__, path="../data/"))
+    gmm.weights = bob.io.base.load(datafile("weightsAfterML.hdf5", __name__, path="../data/"))
+
+    threshold = 0.001
+    gmm.variance_thresholds = threshold
+
+    # Test against the model
+    score_mean_ref = -1.50379e+06
+    score = gmm.log_likelihood(ar).sum()
+    score /= len(ar)
+
+    # Compare current results to torch3vision
+    assert abs(score-score_mean_ref)/score_mean_ref<1e-4
-- 
GitLab