From c3ad33f3a9acf5b8004ff27938eb178a43ee7a11 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Tue, 12 Apr 2022 14:18:41 +0200
Subject: [PATCH] [factor_analysis] test fit without prior

---
 bob/learn/em/factor_analysis.py           |  25 +--
 bob/learn/em/test/test_factor_analysis.py | 178 ++++++++++++----------
 2 files changed, 103 insertions(+), 100 deletions(-)

diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index 3a6b1de..1f2e8c3 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -4,7 +4,6 @@
 
 import logging
 
-import dask
 import numpy as np
 
 from sklearn.base import BaseEstimator
@@ -183,7 +182,7 @@ class FactorAnalysisBase(BaseEstimator):
 
         return len(unique_labels(y))
 
-    def initialize(self, X, y):
+    def initialize(self, X):
         """
         Accumulating 0th and 1st order statistics. Trains the UBM if needed.
 
@@ -211,17 +210,12 @@ class FactorAnalysisBase(BaseEstimator):
         # Train the UBM if not already trained
         if self.ubm._means is None:
             logger.info(f"FA: Training the UBM with {self.ubm}.")
-            self.ubm.fit(np.vstack(X))  # GMMMachine.fit takes non-labeled data
-
-        logger.info("FA: Projection of training data on the UBM.")
-        ubm_projected_X = [dask.delayed(self.ubm.transform(xx)) for xx in X]
+            self.ubm.fit(X)  # GMMMachine.fit takes non-labeled data
 
         # Initializing the state matrix
         if not hasattr(self, "_U") or not hasattr(self, "_D"):
             self.create_UVD()
 
-        self.initialize_using_stats(ubm_projected_X, y)
 
-    def initialize_using_stats(self, ubm_projected_X, y):
         # Accumulating 0th and 1st order statistics
         # https://gitlab.idiap.ch/bob/bob.learn.em/-/blob/da92d0e5799d018f311f1bf5cdd5a80e19e142ca/bob/learn/em/cpp/ISVTrainer.cpp#L68
@@ -1159,6 +1153,7 @@ class FactorAnalysisBase(BaseEstimator):
         return self.score_using_stats(model, self.ubm.transform(data))
 
     def fit(self, X, y):
+        self.initialize(X)
         stats = [self.ubm.transform(xx) for xx in X]
         return self.fit_using_stats(stats, y)
 
@@ -1214,20 +1209,6 @@ class ISVMachine(FactorAnalysisBase):
             **kwargs,
         )
 
-    def initialize(self, X, y):
-        """Initializes the ISV parameters and trains a UBM with `X` if needed.
-
-        If no UBM has been defined on init, it is trained with a new GMMMachine.
-
-        Parameters
-        ----------
-        X: np.ndarray of shape(n_clients, n_samples, n_features)
-            Input data for each client.
-        y: np.ndarray of shape(n_clients,)
-            Client labels.
-        """
-        return super().initialize(X, y)
-
     def e_step(self, X, y, n_acc, f_acc):
         """
         E-step of the EM algorithm
diff --git a/bob/learn/em/test/test_factor_analysis.py b/bob/learn/em/test/test_factor_analysis.py
index f5a85fd..ca1bd2a 100644
--- a/bob/learn/em/test/test_factor_analysis.py
+++ b/bob/learn/em/test/test_factor_analysis.py
@@ -443,7 +443,7 @@ def test_ISVMachine():
     np.testing.assert_allclose(score, score_ref, atol=eps)
 
 
-def test_ISV_fit():
+def test_ISV_JFA_fit():
     np.random.seed(10)
     data_class1 = np.random.normal(0, 0.5, (10, 3))
     data_class2 = np.random.normal(-0.2, 0.2, (10, 3))
@@ -455,84 +455,106 @@
     prior_gmm.means = np.vstack(
         (np.random.normal(0, 0.5, (1, 3)), np.random.normal(1, 0.5, (1, 3)))
     )
-    # All nice and round diagonal covariance
+
     prior_gmm.variances = np.ones((2, 3)) * 0.5
     prior_gmm.weights = np.array([0.3, 0.7])
 
-    # Finally doing the ISV training
-    isv = ISVMachine(
-        2,
-        ubm=prior_gmm,
-        relevance_factor=4,
-        em_iterations=50,
-    )
-    isv.fit(data, labels)
-
-    # Printing the session offset w.r.t each Gaussian component
-    U_ref = [
-        [-2.86662863e-02, 4.45865461e-04],
-        [-4.51712419e-03, 7.02577809e-05],
-        [7.91269855e-02, -1.23071365e-03],
-        [3.27129434e-02, -5.08805760e-04],
-        [9.17898003e-02, -1.42766668e-03],
-        [1.29496881e-01, -2.01414952e-03],
-    ]
-    # TODO(tiago): The reference used to be the values below but are different now
-    # U_ref = [
-    #     [-0.01, -0.027],
-    #     [-0.002, -0.004],
-    #     [0.028, 0.074],
-    #     [0.012, 0.03],
-    #     [0.033, 0.085],
-    #     [0.046, 0.12],
-    # ]
-    np.testing.assert_allclose(isv.U, U_ref, atol=1e-7)
-
-
-def test_JFA_fit():
-    np.random.seed(10)
-    data_class1 = np.random.normal(0, 0.5, (10, 3))
-    data_class2 = np.random.normal(-0.2, 0.2, (10, 3))
-    data = np.concatenate([data_class1, data_class2], axis=0)
-    labels = [0] * 10 + [1] * 10
-
-    # Creating a fake prior with 2 gaussians
-    prior_gmm = GMMMachine(2)
-    prior_gmm.means = np.vstack(
-        (np.random.normal(0, 0.5, (1, 3)), np.random.normal(1, 0.5, (1, 3)))
-    )
-
-    # All nice and round diagonal covariance
-    prior_gmm.variances = np.ones((2, 3)) * 0.5
-    prior_gmm.weights = np.array([0.3, 0.7])
-
-    # Finally doing the JFA training
-    jfa = JFAMachine(
-        2,
-        2,
-        ubm=prior_gmm,
-        relevance_factor=4,
-        em_iterations=50,
-    )
-    jfa.fit(data, labels)
-
-    # Printing the session offset w.r.t each Gaussian component
-    V_ref = [
-        [-0.00459188, 0.00463761],
-        [-0.06622346, 0.06688288],
-        [0.41800691, -0.4221692],
-        [0.40218688, -0.40619164],
-        [0.61849675, -0.6246554],
-        [0.57576069, -0.5814938],
-    ]
-    # TODO(tiago): The reference used to be the values below but are different now
-    # V_ref = [
-    #     [0.003, -0.006],
-    #     [0.041, -0.084],
-    #     [-0.261, 0.53],
-    #     [-0.252, 0.51],
-    #     [-0.387, 0.785],
-    #     [-0.36, 0.73],
-    # ]
-    np.testing.assert_allclose(jfa.V, V_ref, atol=1e-7)
+    for prior, machine_type, ref in [
+        (
+            None,
+            "isv",
+            [
+                [0.02619036, 0.07607595],
+                [-0.02570657, -0.07451667],
+                [-0.0430513, -0.12514552],
+                [-0.09729266, -0.28582205],
+                [-0.01035388, -0.03041718],
+                [0.0733034, 0.21534741],
+            ],
+        ),
+        (
+            prior_gmm,
+            "isv",
+            [
+                [-0.02361267, 0.0157274],
+                [-0.00372588, 0.00248165],
+                [0.06517179, -0.04340818],
+                [0.02694231, -0.01794513],
+                [0.07560949, -0.05036029],
+                [0.10668997, -0.07106169],
+            ],
+            # TODO(tiago): The reference used to be the values below but are different now
+            # [
+            #     [-0.01, -0.027],
+            #     [-0.002, -0.004],
+            #     [0.028, 0.074],
+            #     [0.012, 0.03],
+            #     [0.033, 0.085],
+            #     [0.046, 0.12],
+            # ]
+        ),
+        (
+            None,
+            "jfa",
+            [
+                [-1.72285693e-01, 1.47171193e-01],
+                [-1.08402014e-01, 9.25999920e-02],
+                [1.55349449e-02, -1.32703786e-02],
+                [2.13389657e-04, -1.82283334e-04],
+                [1.84127661e-05, -1.57286929e-05],
+                [-1.90492196e-04, 1.62723691e-04],
+            ],
+        ),
+        (
+            prior_gmm,
+            "jfa",
+            [
+                [6.54547662e-03, 1.98699266e-04],
+                [9.48510389e-02, 2.87936736e-03],
+                [-5.98879972e-01, -1.81800375e-02],
+                [-5.76350228e-01, -1.74961082e-02],
+                [-8.86302168e-01, -2.69052355e-02],
+                [-8.25011907e-01, -2.50446636e-02],
+            ],
+            # TODO(tiago): The reference used to be the values below but are different now
+            # [[ 0.003 -0.006]
+            #  [ 0.041 -0.084]
+            #  [-0.261  0.53 ]
+            #  [-0.252  0.51 ]
+            #  [-0.387  0.785]
+            #  [-0.36   0.73 ]]
+        ),
+    ]:
+        ref = np.asarray(ref)
+        ubm_kwargs = dict(n_gaussians=2) if prior is None else None
+
+        # Doing the training
+        if machine_type == "isv":
+            machine = ISVMachine(
+                2,
+                ubm=prior,
+                relevance_factor=4,
+                em_iterations=50,
+                ubm_kwargs=ubm_kwargs,
+                seed=10,
+            )
+            test_attr = "U"
+        else:
+            machine = JFAMachine(
+                2,
+                2,
+                ubm=prior,
+                relevance_factor=4,
+                em_iterations=50,
+                ubm_kwargs=ubm_kwargs,
+                seed=10,
+            )
+            test_attr = "V"
+        machine.fit(data, labels)
+        arr = getattr(machine, test_attr)
+        np.testing.assert_allclose(
+            arr,
+            ref,
+            atol=1e-7,
+            err_msg=f"Test failed with prior={prior} and machine_type={machine_type}",
+        )
--
GitLab
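
Usage note: the code path exercised by the new test lets fit() train the UBM
itself when no prior GMM is supplied, driven by the ubm_kwargs and seed
arguments seen in the diff. The sketch below shows that pattern outside the
test suite; it assumes ISVMachine is importable from bob.learn.em (the import
is not part of this patch) and otherwise uses only calls visible in the diff.

    import numpy as np

    from bob.learn.em import ISVMachine  # assumed import path

    # Same toy setup as test_ISV_JFA_fit: two clients, 10 samples of 3
    # features each.
    np.random.seed(10)
    data = np.concatenate(
        [np.random.normal(0, 0.5, (10, 3)), np.random.normal(-0.2, 0.2, (10, 3))],
        axis=0,
    )
    labels = [0] * 10 + [1] * 10

    # No ubm= prior is passed. ubm_kwargs tells fit() how to build and train
    # a fresh GMMMachine inside initialize(X); seed makes that reproducible.
    isv = ISVMachine(
        2,  # dimension of the session subspace U
        relevance_factor=4,
        em_iterations=50,
        ubm_kwargs=dict(n_gaussians=2),
        seed=10,
    )
    isv.fit(data, labels)

    # isv.U stacks the session offsets per Gaussian component:
    # (n_gaussians * n_features) x r_U, i.e. 6 x 2 for this configuration.
    print(isv.U.shape)

Passing a pre-trained GMMMachine as ubm= keeps working as before, which is
what the prior_gmm cases of the parametrized test cover.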