From 5ce317d2f234b5eff6a2035253d2a123c1262b1c Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Mon, 11 Apr 2022 14:30:29 +0200 Subject: [PATCH] [factor_analysis] Still allow fit and init using gmm stats --- bob/learn/em/factor_analysis.py | 41 +++++++++++++++++---------- bob/learn/em/test/test_jfa.py | 2 +- bob/learn/em/test/test_jfa_trainer.py | 22 +++++++------- 3 files changed, 38 insertions(+), 27 deletions(-) diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py index 18cdaa0..184c5b7 100644 --- a/bob/learn/em/factor_analysis.py +++ b/bob/learn/em/factor_analysis.py @@ -104,6 +104,9 @@ class FactorAnalysisBase(BaseEstimator): self.relevance_factor = relevance_factor + if ubm is not None: + self.create_UVD() + @property def feature_dimension(self): """Get the UBM Dimension""" @@ -216,6 +219,9 @@ class FactorAnalysisBase(BaseEstimator): if not hasattr(self, "_U") or not hasattr(self, "_D"): self.create_UVD() + self.initialize_using_stats(ubm_projected_X, y) + + def initialize_using_stats(self, ubm_projected_X, y): # Accumulating 0th and 1st order statistics # https://gitlab.idiap.ch/bob/bob.learn.em/-/blob/da92d0e5799d018f311f1bf5cdd5a80e19e142ca/bob/learn/em/cpp/ISVTrainer.cpp#L68 # 0th order stats @@ -814,7 +820,7 @@ class FactorAnalysisBase(BaseEstimator): np.zeros( ( self.r_U, - y.count(y_i), + np.sum(y == y_i), ) ) ) @@ -1192,7 +1198,7 @@ class ISVMachine(FactorAnalysisBase): ubm=None, **gmm_kwargs, ): - super(ISVMachine, self).__init__( + super().__init__( r_U=r_U, relevance_factor=relevance_factor, em_iterations=em_iterations, @@ -1213,7 +1219,7 @@ class ISVMachine(FactorAnalysisBase): y: np.ndarray of shape(n_clients,) Client labels. """ - return super(ISVMachine, self).initialize(X, y) + return super().initialize(X, y) def e_step(self, X, y, n_acc, f_acc): """ @@ -1252,7 +1258,7 @@ class ISVMachine(FactorAnalysisBase): self.update_U(acc_U_A1, acc_U_A2) - def fit(self, X, y): + def fit_using_stats(self, X, y): """ Trains the U matrix (session variability matrix) @@ -1270,10 +1276,10 @@ class ISVMachine(FactorAnalysisBase): """ - y = np.array(y).tolist() if not isinstance(y, list) else y + y = np.asarray(y) # TODO: Point of MAP-REDUCE - n_acc, f_acc = self.initialize(X, y) + n_acc, f_acc = self.initialize_using_stats(X, y) for i in range(self.em_iterations): logger.info("U Training: Iteration %d", i) # TODO: Point of MAP-REDUCE @@ -1412,20 +1418,25 @@ class JFAMachine(FactorAnalysisBase): """ def __init__( - self, ubm, r_U, r_V, em_iterations=10, relevance_factor=4.0, seed=0 + self, + ubm, + r_U, + r_V, + em_iterations=10, + relevance_factor=4.0, + seed=0, + **kwargs, ): - super(JFAMachine, self).__init__( - ubm, + super().__init__( + ubm=ubm, r_U=r_U, r_V=r_V, relevance_factor=relevance_factor, em_iterations=em_iterations, seed=seed, + **kwargs, ) - def initialize(self, X, y): - return super(JFAMachine, self).initialize(X, y) - def e_step_v(self, X, y, n_acc, f_acc): """ ISV E-step for the V matrix. @@ -1769,7 +1780,7 @@ class JFAMachine(FactorAnalysisBase): """ return self.enroll([self.ubm.transform(X)], iterations) - def fit(self, X, y): + def fit_using_stats(self, X, y): """ Trains the U matrix (session variability matrix) @@ -1795,10 +1806,10 @@ class JFAMachine(FactorAnalysisBase): ): self.create_UVD() - y = np.array(y).tolist() if not isinstance(y, list) else y + y = np.asarray(y) # TODO: Point of MAP-REDUCE - n_acc, f_acc = self.initialize(X, y) + n_acc, f_acc = self.initialize_using_stats(X, y) # Updating V for i in range(self.em_iterations): diff --git a/bob/learn/em/test/test_jfa.py b/bob/learn/em/test/test_jfa.py index 8e20b4a..dc15b71 100644 --- a/bob/learn/em/test/test_jfa.py +++ b/bob/learn/em/test/test_jfa.py @@ -65,7 +65,7 @@ def test_ISVMachine(): ubm.variances = np.array([[1, 2, 1], [2, 1, 2]], "float64") # Creates a ISVMachine - isv_machine = ISVMachine(ubm, r_U=2, em_iterations=10) + isv_machine = ISVMachine(ubm=ubm, r_U=2, em_iterations=10) isv_machine.U = np.array( [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], "float64" ) diff --git a/bob/learn/em/test/test_jfa_trainer.py b/bob/learn/em/test/test_jfa_trainer.py index cc3d6ad..e38beec 100644 --- a/bob/learn/em/test/test_jfa_trainer.py +++ b/bob/learn/em/test/test_jfa_trainer.py @@ -126,7 +126,7 @@ def test_JFATrainAndEnrol(): it.U = copy.deepcopy(M_u) it.V = copy.deepcopy(M_v) it.D = copy.deepcopy(M_d) - it.fit(TRAINING_STATS_X, TRAINING_STATS_y) + it.fit_using_stats(TRAINING_STATS_X, TRAINING_STATS_y) v_ref = np.array( [ @@ -225,7 +225,7 @@ def test_JFATrainAndEnrolWithNumpy(): it.U = copy.deepcopy(M_u) it.V = copy.deepcopy(M_v) it.D = copy.deepcopy(M_d) - it.fit(TRAINING_STATS_X, TRAINING_STATS_y) + it.fit_using_stats(TRAINING_STATS_X, TRAINING_STATS_y) v_ref = np.array( [ @@ -337,14 +337,14 @@ def test_ISVTrainAndEnrol(): ubm.variances = UBM_VAR.reshape((2, 3)) it = ISVMachine( - ubm, + ubm=ubm, r_U=2, relevance_factor=4.0, em_iterations=10, ) it.U = copy.deepcopy(M_u) - it = it.fit(TRAINING_STATS_X, TRAINING_STATS_y) + it = it.fit_using_stats(TRAINING_STATS_X, TRAINING_STATS_y) np.testing.assert_allclose(it.D, d_ref, rtol=eps, atol=1e-8) np.testing.assert_allclose(it.U, u_ref, rtol=eps, atol=1e-8) @@ -417,14 +417,14 @@ def test_ISVTrainAndEnrolWithNumpy(): ubm.variances = UBM_VAR.reshape((2, 3)) it = ISVMachine( - ubm, + ubm=ubm, r_U=2, relevance_factor=4.0, em_iterations=10, ) it.U = copy.deepcopy(M_u) - it = it.fit(TRAINING_STATS_X, TRAINING_STATS_y) + it = it.fit_using_stats(TRAINING_STATS_X, TRAINING_STATS_y) np.testing.assert_allclose(it.D, d_ref, rtol=eps, atol=1e-8) np.testing.assert_allclose(it.U, u_ref, rtol=eps, atol=1e-8) @@ -466,13 +466,13 @@ def test_JFATrainInitialize(): it = JFAMachine(ubm, 2, 2, em_iterations=10) # first round - it.initialize(TRAINING_STATS_X, TRAINING_STATS_y) + it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y) u1 = it.U v1 = it.V d1 = it.D # second round - it.initialize(TRAINING_STATS_X, TRAINING_STATS_y) + it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y) u2 = it.U v2 = it.V d2 = it.D @@ -493,15 +493,15 @@ def test_ISVTrainInitialize(): ubm.variances = UBM_VAR.reshape((2, 3)) # ISV - it = ISVMachine(ubm, 2, em_iterations=10) + it = ISVMachine(2, em_iterations=10, ubm=ubm) # it.rng = rng - it.initialize(TRAINING_STATS_X, TRAINING_STATS_y) + it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y) u1 = copy.deepcopy(it.U) d1 = copy.deepcopy(it.D) # second round - it.initialize(TRAINING_STATS_X, TRAINING_STATS_y) + it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y) u2 = it.U d2 = it.D -- GitLab