diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index ebb8c6ac3674a5e70a99d5953fb56235b3bb9b38..3a6b1dee18c58281983e3a63463dd41bff07b774 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -78,7 +78,7 @@ class FactorAnalysisBase(BaseEstimator):
     ubm: :py:class:`bob.learn.em.GMMMachine`
         A trained UBM (Universal Background Model) or a parametrized
         :py:class:`bob.learn.em.GMMMachine` to train the UBM with. If None,
-        `gmm_kwargs` are passed as parameters of a new
+        `ubm_kwargs` are passed as parameters of a new
         :py:class:`bob.learn.em.GMMMachine`.
     """

@@ -90,12 +90,12 @@ class FactorAnalysisBase(BaseEstimator):
         em_iterations=10,
         seed=0,
         ubm=None,
-        gmm_kwargs=None,
+        ubm_kwargs=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.ubm = ubm
-        self.gmm_kwargs = gmm_kwargs
+        self.ubm_kwargs = ubm_kwargs
         self.em_iterations = em_iterations
         self.seed = seed

@@ -206,7 +206,7 @@ class FactorAnalysisBase(BaseEstimator):

         if self.ubm is None:
             logger.info("FA: Creating a new GMMMachine.")
-            self.ubm = GMMMachine(**self.gmm_kwargs)
+            self.ubm = GMMMachine(**self.ubm_kwargs)

         # Train the UBM if not already trained
         if self.ubm._means is None:
@@ -1189,7 +1189,7 @@ class ISVMachine(FactorAnalysisBase):

     ubm: :py:class:`bob.learn.em.GMMMachine` or None
         A trained UBM (Universal Background Model). If None, the UBM is trained with
-        a new :py:class:`bob.learn.em.GMMMachine` when fit is called, with `gmm_kwargs`
+        a new :py:class:`bob.learn.em.GMMMachine` when fit is called, with `ubm_kwargs`
         as parameters.
     """

@@ -1201,7 +1201,8 @@
     def __init__(
         self,
         r_U,
         em_iterations=10,
         relevance_factor=4.0,
         seed=0,
         ubm=None,
-        **gmm_kwargs,
+        ubm_kwargs=None,
+        **kwargs,
     ):
         super().__init__(
             r_U=r_U,
@@ -1209,7 +1210,8 @@
             em_iterations=em_iterations,
             seed=seed,
             ubm=ubm,
-            **gmm_kwargs,
+            ubm_kwargs=ubm_kwargs,
+            **kwargs,
         )

     def initialize(self, X, y):
@@ -1424,12 +1426,13 @@

     def __init__(
         self,
-        ubm,
         r_U,
         r_V,
         em_iterations=10,
         relevance_factor=4.0,
         seed=0,
+        ubm=None,
+        ubm_kwargs=None,
         **kwargs,
     ):
         super().__init__(
@@ -1439,6 +1442,7 @@
             relevance_factor=relevance_factor,
             em_iterations=em_iterations,
             seed=seed,
+            ubm_kwargs=ubm_kwargs,
             **kwargs,
         )

diff --git a/bob/learn/em/test/test_factor_analysis.py b/bob/learn/em/test/test_factor_analysis.py
index d1b214162d8e794ede2b59a8d32c4a0d363d1d61..f5a85fdabe3a4011c87673f58048d3163edb1d3b 100644
--- a/bob/learn/em/test/test_factor_analysis.py
+++ b/bob/learn/em/test/test_factor_analysis.py
@@ -118,7 +118,7 @@ def test_JFATrainAndEnrol():
     ubm = GMMMachine(2, 3)
     ubm.means = UBM_MEAN.reshape((2, 3))
     ubm.variances = UBM_VAR.reshape((2, 3))
-    it = JFAMachine(ubm, 2, 2, em_iterations=10)
+    it = JFAMachine(2, 2, em_iterations=10, ubm=ubm)

     it.U = copy.deepcopy(M_u)
     it.V = copy.deepcopy(M_v)
@@ -314,7 +314,7 @@ def test_JFATrainInitialize():
     ubm.variances = UBM_VAR.reshape((2, 3))

     # JFA
-    it = JFAMachine(ubm, 2, 2, em_iterations=10)
+    it = JFAMachine(2, 2, em_iterations=10, ubm=ubm)

     # first round
     it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
@@ -379,7 +379,7 @@ def test_JFAMachine():
     gs.sum_pxx = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]], "float64")

     # Creates a JFAMachine
-    m = JFAMachine(ubm, 2, 2, em_iterations=10)
+    m = JFAMachine(2, 2, em_iterations=10, ubm=ubm)
     m.U = np.array(
         [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], "float64"
     )
@@ -470,14 +470,69 @@ def test_ISV_fit():
     isv.fit(data, labels)

     # Printing the session offset w.r.t each Gaussian component
-    np.testing.assert_allclose(
-        isv.U,
-        [
-            [-0.01, -0.027],
-            [-0.002, -0.004],
-            [0.028, 0.074],
-            [0.012, 0.03],
-            [0.033, 0.085],
-            [0.046, 0.12],
-        ],
+    U_ref = [
+        [-2.86662863e-02, 4.45865461e-04],
+        [-4.51712419e-03, 7.02577809e-05],
+        [7.91269855e-02, -1.23071365e-03],
+        [3.27129434e-02, -5.08805760e-04],
+        [9.17898003e-02, -1.42766668e-03],
+        [1.29496881e-01, -2.01414952e-03],
+    ]
+    # TODO(tiago): The reference used to be the values below, but they are different now
+    # U_ref = [
+    #     [-0.01, -0.027],
+    #     [-0.002, -0.004],
+    #     [0.028, 0.074],
+    #     [0.012, 0.03],
+    #     [0.033, 0.085],
+    #     [0.046, 0.12],
+    # ]
+    np.testing.assert_allclose(isv.U, U_ref, atol=1e-7)
+
+
+def test_JFA_fit():
+    np.random.seed(10)
+    data_class1 = np.random.normal(0, 0.5, (10, 3))
+    data_class2 = np.random.normal(-0.2, 0.2, (10, 3))
+    data = np.concatenate([data_class1, data_class2], axis=0)
+    labels = [0] * 10 + [1] * 10
+
+    # Creating a fake prior with 2 Gaussians
+    prior_gmm = GMMMachine(2)
+    prior_gmm.means = np.vstack(
+        (np.random.normal(0, 0.5, (1, 3)), np.random.normal(1, 0.5, (1, 3)))
     )
+
+    # All nice and round diagonal covariance
+    prior_gmm.variances = np.ones((2, 3)) * 0.5
+    prior_gmm.weights = np.array([0.3, 0.7])
+
+    # Finally doing the JFA training
+    jfa = JFAMachine(
+        2,
+        2,
+        ubm=prior_gmm,
+        relevance_factor=4,
+        em_iterations=50,
+    )
+    jfa.fit(data, labels)
+
+    # Printing the session offset w.r.t each Gaussian component
+    V_ref = [
+        [-0.00459188, 0.00463761],
+        [-0.06622346, 0.06688288],
+        [0.41800691, -0.4221692],
+        [0.40218688, -0.40619164],
+        [0.61849675, -0.6246554],
+        [0.57576069, -0.5814938],
+    ]
+    # TODO(tiago): The reference used to be the values below, but they are different now
+    # V_ref = [
+    #     [0.003, -0.006],
+    #     [0.041, -0.084],
+    #     [-0.261, 0.53],
+    #     [-0.252, 0.51],
+    #     [-0.387, 0.785],
+    #     [-0.36, 0.73],
+    # ]
+    np.testing.assert_allclose(jfa.V, V_ref, atol=1e-7)
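Usage note (not part of the patch): a minimal sketch of what the renamed `ubm_kwargs` parameter enables after this change. It assumes `ISVMachine` is importable from `bob.learn.em` and that `GMMMachine`'s first parameter is `n_gaussians`; the data and hyperparameter values below are illustrative only.

import numpy as np

from bob.learn.em import ISVMachine

np.random.seed(10)
data = np.random.normal(0, 0.5, (20, 3))  # toy features: 20 samples, 3 dims
labels = [0] * 10 + [1] * 10  # one class label per sample

# No pre-trained UBM is passed, so fit() creates GMMMachine(**ubm_kwargs)
# and trains it before running the ISV training itself.
isv = ISVMachine(r_U=2, em_iterations=10, ubm_kwargs=dict(n_gaussians=2))
isv.fit(data, labels)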