From 5ce317d2f234b5eff6a2035253d2a123c1262b1c Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Mon, 11 Apr 2022 14:30:29 +0200
Subject: [PATCH] [factor_analysis] Still allow fit and init using gmm stats

---
 bob/learn/em/factor_analysis.py       | 41 +++++++++++++++++----------
 bob/learn/em/test/test_jfa.py         |  2 +-
 bob/learn/em/test/test_jfa_trainer.py | 22 +++++++-------
 3 files changed, 38 insertions(+), 27 deletions(-)

diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index 18cdaa0..184c5b7 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -104,6 +104,9 @@ class FactorAnalysisBase(BaseEstimator):
 
         self.relevance_factor = relevance_factor
 
+        if ubm is not None:
+            self.create_UVD()
+
     @property
     def feature_dimension(self):
         """Get the UBM Dimension"""
@@ -216,6 +219,9 @@ class FactorAnalysisBase(BaseEstimator):
         if not hasattr(self, "_U") or not hasattr(self, "_D"):
             self.create_UVD()
 
+        self.initialize_using_stats(ubm_projected_X, y)
+
+    def initialize_using_stats(self, ubm_projected_X, y):
         # Accumulating 0th and 1st order statistics
         # https://gitlab.idiap.ch/bob/bob.learn.em/-/blob/da92d0e5799d018f311f1bf5cdd5a80e19e142ca/bob/learn/em/cpp/ISVTrainer.cpp#L68
         # 0th order stats
@@ -814,7 +820,7 @@ class FactorAnalysisBase(BaseEstimator):
                 np.zeros(
                     (
                         self.r_U,
-                        y.count(y_i),
+                        np.sum(y == y_i),
                     )
                 )
             )
@@ -1192,7 +1198,7 @@ class ISVMachine(FactorAnalysisBase):
         ubm=None,
         **gmm_kwargs,
     ):
-        super(ISVMachine, self).__init__(
+        super().__init__(
             r_U=r_U,
             relevance_factor=relevance_factor,
             em_iterations=em_iterations,
@@ -1213,7 +1219,7 @@ class ISVMachine(FactorAnalysisBase):
         y: np.ndarray of shape(n_clients,)
             Client labels.
         """
-        return super(ISVMachine, self).initialize(X, y)
+        return super().initialize(X, y)
 
     def e_step(self, X, y, n_acc, f_acc):
         """
@@ -1252,7 +1258,7 @@ class ISVMachine(FactorAnalysisBase):
 
         self.update_U(acc_U_A1, acc_U_A2)
 
-    def fit(self, X, y):
+    def fit_using_stats(self, X, y):
         """
         Trains the U matrix (session variability matrix)
 
@@ -1270,10 +1276,10 @@ class ISVMachine(FactorAnalysisBase):
 
         """
 
-        y = np.array(y).tolist() if not isinstance(y, list) else y
+        y = np.asarray(y)
 
         # TODO: Point of MAP-REDUCE
-        n_acc, f_acc = self.initialize(X, y)
+        n_acc, f_acc = self.initialize_using_stats(X, y)
         for i in range(self.em_iterations):
             logger.info("U Training: Iteration %d", i)
             # TODO: Point of MAP-REDUCE
@@ -1412,20 +1418,25 @@ class JFAMachine(FactorAnalysisBase):
     """
 
     def __init__(
-        self, ubm, r_U, r_V, em_iterations=10, relevance_factor=4.0, seed=0
+        self,
+        ubm,
+        r_U,
+        r_V,
+        em_iterations=10,
+        relevance_factor=4.0,
+        seed=0,
+        **kwargs,
     ):
-        super(JFAMachine, self).__init__(
-            ubm,
+        super().__init__(
+            ubm=ubm,
             r_U=r_U,
             r_V=r_V,
             relevance_factor=relevance_factor,
             em_iterations=em_iterations,
             seed=seed,
+            **kwargs,
         )
 
-    def initialize(self, X, y):
-        return super(JFAMachine, self).initialize(X, y)
-
     def e_step_v(self, X, y, n_acc, f_acc):
         """
         ISV E-step for the V matrix.
@@ -1769,7 +1780,7 @@ class JFAMachine(FactorAnalysisBase):
         """
         return self.enroll([self.ubm.transform(X)], iterations)
 
-    def fit(self, X, y):
+    def fit_using_stats(self, X, y):
         """
         Trains the U matrix (session variability matrix)
 
@@ -1795,10 +1806,10 @@ class JFAMachine(FactorAnalysisBase):
         ):
             self.create_UVD()
 
-        y = np.array(y).tolist() if not isinstance(y, list) else y
+        y = np.asarray(y)
 
         # TODO: Point of MAP-REDUCE
-        n_acc, f_acc = self.initialize(X, y)
+        n_acc, f_acc = self.initialize_using_stats(X, y)
 
         # Updating V
         for i in range(self.em_iterations):
diff --git a/bob/learn/em/test/test_jfa.py b/bob/learn/em/test/test_jfa.py
index 8e20b4a..dc15b71 100644
--- a/bob/learn/em/test/test_jfa.py
+++ b/bob/learn/em/test/test_jfa.py
@@ -65,7 +65,7 @@ def test_ISVMachine():
     ubm.variances = np.array([[1, 2, 1], [2, 1, 2]], "float64")
 
     # Creates a ISVMachine
-    isv_machine = ISVMachine(ubm, r_U=2, em_iterations=10)
+    isv_machine = ISVMachine(ubm=ubm, r_U=2, em_iterations=10)
     isv_machine.U = np.array(
         [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], "float64"
     )
diff --git a/bob/learn/em/test/test_jfa_trainer.py b/bob/learn/em/test/test_jfa_trainer.py
index cc3d6ad..e38beec 100644
--- a/bob/learn/em/test/test_jfa_trainer.py
+++ b/bob/learn/em/test/test_jfa_trainer.py
@@ -126,7 +126,7 @@ def test_JFATrainAndEnrol():
     it.U = copy.deepcopy(M_u)
     it.V = copy.deepcopy(M_v)
     it.D = copy.deepcopy(M_d)
-    it.fit(TRAINING_STATS_X, TRAINING_STATS_y)
+    it.fit_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
 
     v_ref = np.array(
         [
@@ -225,7 +225,7 @@ def test_JFATrainAndEnrolWithNumpy():
     it.U = copy.deepcopy(M_u)
     it.V = copy.deepcopy(M_v)
     it.D = copy.deepcopy(M_d)
-    it.fit(TRAINING_STATS_X, TRAINING_STATS_y)
+    it.fit_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
 
     v_ref = np.array(
         [
@@ -337,14 +337,14 @@ def test_ISVTrainAndEnrol():
     ubm.variances = UBM_VAR.reshape((2, 3))
 
     it = ISVMachine(
-        ubm,
+        ubm=ubm,
         r_U=2,
         relevance_factor=4.0,
         em_iterations=10,
     )
 
     it.U = copy.deepcopy(M_u)
-    it = it.fit(TRAINING_STATS_X, TRAINING_STATS_y)
+    it = it.fit_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
 
     np.testing.assert_allclose(it.D, d_ref, rtol=eps, atol=1e-8)
     np.testing.assert_allclose(it.U, u_ref, rtol=eps, atol=1e-8)
@@ -417,14 +417,14 @@ def test_ISVTrainAndEnrolWithNumpy():
     ubm.variances = UBM_VAR.reshape((2, 3))
 
     it = ISVMachine(
-        ubm,
+        ubm=ubm,
         r_U=2,
         relevance_factor=4.0,
         em_iterations=10,
     )
 
     it.U = copy.deepcopy(M_u)
-    it = it.fit(TRAINING_STATS_X, TRAINING_STATS_y)
+    it = it.fit_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
 
     np.testing.assert_allclose(it.D, d_ref, rtol=eps, atol=1e-8)
     np.testing.assert_allclose(it.U, u_ref, rtol=eps, atol=1e-8)
@@ -466,13 +466,13 @@ def test_JFATrainInitialize():
     it = JFAMachine(ubm, 2, 2, em_iterations=10)
     # first round
 
-    it.initialize(TRAINING_STATS_X, TRAINING_STATS_y)
+    it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
     u1 = it.U
     v1 = it.V
     d1 = it.D
 
     # second round
-    it.initialize(TRAINING_STATS_X, TRAINING_STATS_y)
+    it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
     u2 = it.U
     v2 = it.V
     d2 = it.D
@@ -493,15 +493,15 @@ def test_ISVTrainInitialize():
     ubm.variances = UBM_VAR.reshape((2, 3))
 
     # ISV
-    it = ISVMachine(ubm, 2, em_iterations=10)
+    it = ISVMachine(2, em_iterations=10, ubm=ubm)
     # it.rng = rng
 
-    it.initialize(TRAINING_STATS_X, TRAINING_STATS_y)
+    it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
     u1 = copy.deepcopy(it.U)
     d1 = copy.deepcopy(it.D)
 
     # second round
-    it.initialize(TRAINING_STATS_X, TRAINING_STATS_y)
+    it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
     u2 = it.U
     d2 = it.D
 
-- 
GitLab