diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index 6e8fdefbec11a2f6fcf9176a80bfa2a95901c9ad..b65da8263eb6190b0a55e038ac7837cfba347ebd 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -180,12 +180,7 @@ class FactorAnalysisBase(BaseEstimator):
 
     @U.setter
     def U(self, value):
-        U_shape = (self.supervector_dimension, self.r_U)
-        if value.shape != U_shape:
-            raise ValueError(
-                f"U must be a numpy array of shape {U_shape}, but a matrix of shape {value.shape} was provided."
-            )
-        self._U = value
+        self._U = np.array(value)
 
     @property
     def D(self):
@@ -194,12 +189,7 @@ class FactorAnalysisBase(BaseEstimator):
 
     @D.setter
     def D(self, value):
-        D_shape = (self.supervector_dimension,)
-        if value.shape != D_shape:
-            raise ValueError(
-                f"D must be a numpy array of shape {D_shape}, but a matrix of shape {value.shape} was provided."
-            )
-        self._D = value
+        self._D = np.array(value)
 
     @property
     def V(self):
@@ -208,12 +198,7 @@ class FactorAnalysisBase(BaseEstimator):
 
     @V.setter
     def V(self, value):
-        V_shape = (self.supervector_dimension, self.r_V)
-        if value.shape != V_shape:
-            raise ValueError(
-                f"V must be a numpy array of shape {V_shape}, but a matrix of shape {value.shape} was provided."
-            )
-        self._V = value
+        self._V = np.array(value)
 
     def estimate_number_of_classes(self, y):
         """
diff --git a/bob/learn/em/test/test_factor_analysis.py b/bob/learn/em/test/test_factor_analysis.py
index 524ad18e42c504767cbe0ca80793352624f56a9c..d2399b26bfaa4f0711c67bea7c34340a8921b5a7 100644
--- a/bob/learn/em/test/test_factor_analysis.py
+++ b/bob/learn/em/test/test_factor_analysis.py
@@ -467,6 +467,32 @@ def test_ISV_JFA_fit():
     means = np.vstack(
         (np.random.normal(0, 0.5, (1, 3)), np.random.normal(1, 0.5, (1, 3)))
     )
+    prior_U = [
+        [-0.150035, -0.44441],
+        [-1.67812, 2.47621],
+        [-0.52885, 0.659141],
+        [-0.538446, 1.67376],
+        [-0.111288, 2.06948],
+        [1.39563, -1.65004],
+    ]
+
+    prior_V = [
+        [0.732467, 0.281321],
+        [0.543212, -0.512974],
+        [1.04108, 0.835224],
+        [-0.363719, -0.324688],
+        [-1.21579, -0.905314],
+        [-0.993204, -0.121991],
+    ]
+
+    prior_D = [
+        0.943986,
+        -0.0900599,
+        -0.528103,
+        0.541502,
+        -0.717824,
+        0.463729,
+    ]
 
     for prior, machine_type, ref in [
         (
@@ -478,53 +504,37 @@ def test_ISV_JFA_fit():
             True,
             "isv",
             [
-                [-0.02361267, 0.0157274],
-                [-0.00372588, 0.00248165],
-                [0.06517179, -0.04340818],
-                [0.02694231, -0.01794513],
-                [0.07560949, -0.05036029],
-                [0.10668997, -0.07106169],
+                [-0.01018673, -0.0266506],
+                [-0.00160621, -0.00420217],
+                [0.02811705, 0.07356008],
+                [0.011624, 0.0304108],
+                [0.03261831, 0.08533629],
+                [0.04602191, 0.12040291],
             ],
-            # TODO(tiago): The reference used to be the values below but are different now
-            # [
-            #     [-0.01, -0.027],
-            #     [-0.002, -0.004],
-            #     [0.028, 0.074],
-            #     [0.012, 0.03],
-            #     [0.033, 0.085],
-            #     [0.046, 0.12],
-            # ]
         ),
         (
             None,
             "jfa",
             [
-                [-0.04687046, -0.06302095],
-                [-0.04380423, -0.05889816],
-                [-0.02083793, -0.0280182],
-                [-0.04728452, -0.06357768],
-                [-0.04371283, -0.05877527],
-                [-0.0203464, -0.0273573],
+                [-0.05673845, -0.0543068],
+                [-0.05302666, -0.05075409],
+                [-0.02522509, -0.02414402],
+                [-0.05723968, -0.05478655],
+                [-0.05291602, -0.05064819],
+                [-0.02463007, -0.0235745],
             ],
         ),
         (
             True,
             "jfa",
             [
-                [6.54547662e-03, 1.98699266e-04],
-                [9.48510389e-02, 2.87936736e-03],
-                [-5.98879972e-01, -1.81800375e-02],
-                [-5.76350228e-01, -1.74961082e-02],
-                [-8.86302168e-01, -2.69052355e-02],
-                [-8.25011907e-01, -2.50446636e-02],
+                [0.002881, -0.00584225],
+                [0.04143539, -0.08402497],
+                [-0.26149924, 0.53028251],
+                [-0.25156832, 0.51014406],
+                [-0.38687765, 0.78453174],
+                [-0.36015821, 0.73034858],
             ],
-            # TODO(tiago): The reference used to be the values below but are different now
-            #   [[ 0.003 -0.006]
-            #    [ 0.041 -0.084]
-            #    [-0.261  0.53 ]
-            #    [-0.252  0.51 ]
-            #    [-0.387  0.785]
-            #    [-0.36   0.73 ]]
         ),
     ]:
         ref = np.asarray(ref)
@@ -535,6 +545,8 @@ def test_ISV_JFA_fit():
 
             if prior is None:
                 ubm = None
+                # we still provide an initial UBM because KMeans training is not
+                # determenistic depending on inputting numpy or dask arrays
                 ubm_kwargs = dict(n_gaussians=2, ubm=_create_ubm_prior(means))
             else:
                 ubm = _create_ubm_prior(means)
@@ -550,9 +562,13 @@ def test_ISV_JFA_fit():
 
             if machine_type == "isv":
                 machine = ISVMachine(2, **machine_kwargs)
+                machine.U = prior_U
                 test_attr = "U"
             else:
                 machine = JFAMachine(2, 2, **machine_kwargs)
+                machine.U = prior_U
+                machine.V = prior_V
+                machine.D = prior_D
                 test_attr = "V"
 
             err_msg = f"Test failed with prior={prior} and machine_type={machine_type} and transform={transform}"