diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index ebb8c6ac3674a5e70a99d5953fb56235b3bb9b38..3a6b1dee18c58281983e3a63463dd41bff07b774 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -78,7 +78,7 @@ class FactorAnalysisBase(BaseEstimator):
     ubm: :py:class:`bob.learn.em.GMMMachine`
         A trained UBM (Universal Background Model) or a parametrized
         :py:class:`bob.learn.em.GMMMachine` to train the UBM with. If None,
-        `gmm_kwargs` are passed as parameters of a new
+        `ubm_kwargs` are passed as parameters of a new
         :py:class:`bob.learn.em.GMMMachine`.
     """

@@ -90,12 +90,12 @@ class FactorAnalysisBase(BaseEstimator):
         em_iterations=10,
         seed=0,
         ubm=None,
-        gmm_kwargs=None,
+        ubm_kwargs=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.ubm = ubm
-        self.gmm_kwargs = gmm_kwargs
+        self.ubm_kwargs = ubm_kwargs
         self.em_iterations = em_iterations
         self.seed = seed

@@ -206,7 +206,7 @@ class FactorAnalysisBase(BaseEstimator):

         if self.ubm is None:
             logger.info("FA: Creating a new GMMMachine.")
-            self.ubm = GMMMachine(**self.gmm_kwargs)
+            self.ubm = GMMMachine(**self.ubm_kwargs)

         # Train the UBM if not already trained
         if self.ubm._means is None:
@@ -1189,7 +1189,7 @@ class ISVMachine(FactorAnalysisBase):

     ubm: :py:class:`bob.learn.em.GMMMachine` or None
         A trained UBM (Universal Background Model). If None, the UBM is trained with
-        a new :py:class:`bob.learn.em.GMMMachine` when fit is called, with `gmm_kwargs`
+        a new :py:class:`bob.learn.em.GMMMachine` when fit is called, with `ubm_kwargs`
         as parameters.
     """

@@ -1201,7 +1201,8 @@
     def __init__(
         self,
         r_U,
         em_iterations=10,
         relevance_factor=4.0,
         seed=0,
         ubm=None,
-        **gmm_kwargs,
+        ubm_kwargs=None,
+        **kwargs,
     ):
         super().__init__(
             r_U=r_U,
@@ -1209,7 +1210,8 @@
             em_iterations=em_iterations,
             seed=seed,
             ubm=ubm,
-            **gmm_kwargs,
+            ubm_kwargs=ubm_kwargs,
+            **kwargs,
         )

     def initialize(self, X, y):
@@ -1424,12 +1426,13 @@

     def __init__(
         self,
-        ubm,
         r_U,
         r_V,
         em_iterations=10,
         relevance_factor=4.0,
         seed=0,
+        ubm=None,
+        ubm_kwargs=None,
         **kwargs,
     ):
         super().__init__(
@@ -1439,6 +1442,7 @@
             relevance_factor=relevance_factor,
             em_iterations=em_iterations,
             seed=seed,
+            ubm_kwargs=ubm_kwargs,
             **kwargs,
         )

diff --git a/bob/learn/em/test/test_factor_analysis.py b/bob/learn/em/test/test_factor_analysis.py
index d1b214162d8e794ede2b59a8d32c4a0d363d1d61..f5a85fdabe3a4011c87673f58048d3163edb1d3b 100644
--- a/bob/learn/em/test/test_factor_analysis.py
+++ b/bob/learn/em/test/test_factor_analysis.py
@@ -118,7 +118,7 @@ def test_JFATrainAndEnrol():
     ubm = GMMMachine(2, 3)
     ubm.means = UBM_MEAN.reshape((2, 3))
     ubm.variances = UBM_VAR.reshape((2, 3))
-    it = JFAMachine(ubm, 2, 2, em_iterations=10)
+    it = JFAMachine(2, 2, em_iterations=10, ubm=ubm)

     it.U = copy.deepcopy(M_u)
     it.V = copy.deepcopy(M_v)
@@ -314,7 +314,7 @@ def test_JFATrainInitialize():
     ubm.variances = UBM_VAR.reshape((2, 3))

     # JFA
-    it = JFAMachine(ubm, 2, 2, em_iterations=10)
+    it = JFAMachine(2, 2, em_iterations=10, ubm=ubm)

     # first round
     it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
@@ -379,7 +379,7 @@ def test_JFAMachine():
     gs.sum_pxx = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]], "float64")

     # Creates a JFAMachine
-    m = JFAMachine(ubm, 2, 2, em_iterations=10)
+    m = JFAMachine(2, 2, em_iterations=10, ubm=ubm)
     m.U = np.array(
         [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], "float64"
     )
@@ -470,14 +470,69 @@ def test_ISV_fit():
     isv.fit(data, labels)

     # Printing the session offset w.r.t each Gaussian component
-    np.testing.assert_allclose(
-        isv.U,
-        [
-            [-0.01, -0.027],
-            [-0.002, -0.004],
-            [0.028, 0.074],
-            [0.012, 0.03],
-            [0.033, 0.085],
-            [0.046, 0.12],
-        ],
+    U_ref = [
+        [-2.86662863e-02, 4.45865461e-04],
+        [-4.51712419e-03, 7.02577809e-05],
+        [7.91269855e-02, -1.23071365e-03],
+        [3.27129434e-02, -5.08805760e-04],
+        [9.17898003e-02, -1.42766668e-03],
+        [1.29496881e-01, -2.01414952e-03],
+    ]
+    # TODO(tiago): The reference used to be the values below, but they are different now
+    # U_ref = [
+    #     [-0.01, -0.027],
+    #     [-0.002, -0.004],
+    #     [0.028, 0.074],
+    #     [0.012, 0.03],
+    #     [0.033, 0.085],
+    #     [0.046, 0.12],
+    # ]
+    np.testing.assert_allclose(isv.U, U_ref, atol=1e-7)
+
+
+def test_JFA_fit():
+    np.random.seed(10)
+    data_class1 = np.random.normal(0, 0.5, (10, 3))
+    data_class2 = np.random.normal(-0.2, 0.2, (10, 3))
+    data = np.concatenate([data_class1, data_class2], axis=0)
+    labels = [0] * 10 + [1] * 10
+
+    # Creating a fake prior with 2 Gaussians
+    prior_gmm = GMMMachine(2)
+    prior_gmm.means = np.vstack(
+        (np.random.normal(0, 0.5, (1, 3)), np.random.normal(1, 0.5, (1, 3)))
     )
+
+    # All nice and round diagonal covariance
+    prior_gmm.variances = np.ones((2, 3)) * 0.5
+    prior_gmm.weights = np.array([0.3, 0.7])
+
+    # Finally doing the JFA training
+    jfa = JFAMachine(
+        2,
+        2,
+        ubm=prior_gmm,
+        relevance_factor=4,
+        em_iterations=50,
+    )
+    jfa.fit(data, labels)
+
+    # Printing the session offset w.r.t each Gaussian component
+    V_ref = [
+        [-0.00459188, 0.00463761],
+        [-0.06622346, 0.06688288],
+        [0.41800691, -0.4221692],
+        [0.40218688, -0.40619164],
+        [0.61849675, -0.6246554],
+        [0.57576069, -0.5814938],
+    ]
+    # TODO(tiago): The reference used to be the values below, but they are different now
+    # V_ref = [
+    #     [0.003, -0.006],
+    #     [0.041, -0.084],
+    #     [-0.261, 0.53],
+    #     [-0.252, 0.51],
+    #     [-0.387, 0.785],
+    #     [-0.36, 0.73],
+    # ]
+    np.testing.assert_allclose(jfa.V, V_ref, atol=1e-7)
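Usage note (not part of the patch): a minimal sketch of what the renamed `ubm_kwargs` parameter enables after this change. It assumes `ISVMachine` is importable from `bob.learn.em` and that `GMMMachine`'s first parameter is `n_gaussians`; the data and hyperparameter values below are illustrative only.

import numpy as np

from bob.learn.em import ISVMachine

np.random.seed(10)
data = np.random.normal(0, 0.5, (20, 3))  # toy features: 20 samples, 3 dims
labels = [0] * 10 + [1] * 10  # one class label per sample

# No pre-trained UBM is passed, so fit() creates GMMMachine(**ubm_kwargs)
# and trains it before running the ISV training itself.
isv = ISVMachine(r_U=2, em_iterations=10, ubm_kwargs=dict(n_gaussians=2))
isv.fit(data, labels)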