Skip to content
Snippets Groups Projects
Commit 5ce317d2 authored by Amir MOHAMMADI
Browse files

[factor_analysis] Still allow fit and init using gmm stats

parent 482d4273
Branches
Tags
1 merge request: !53 "Factor Analysis on pure python"
Pipeline #60142 failed
...@@ -104,6 +104,9 @@ class FactorAnalysisBase(BaseEstimator): ...@@ -104,6 +104,9 @@ class FactorAnalysisBase(BaseEstimator):
self.relevance_factor = relevance_factor self.relevance_factor = relevance_factor
if ubm is not None:
self.create_UVD()
@property @property
def feature_dimension(self): def feature_dimension(self):
"""Get the UBM Dimension""" """Get the UBM Dimension"""
...@@ -216,6 +219,9 @@ class FactorAnalysisBase(BaseEstimator): ...@@ -216,6 +219,9 @@ class FactorAnalysisBase(BaseEstimator):
if not hasattr(self, "_U") or not hasattr(self, "_D"): if not hasattr(self, "_U") or not hasattr(self, "_D"):
self.create_UVD() self.create_UVD()
self.initialize_using_stats(ubm_projected_X, y)
def initialize_using_stats(self, ubm_projected_X, y):
# Accumulating 0th and 1st order statistics # Accumulating 0th and 1st order statistics
# https://gitlab.idiap.ch/bob/bob.learn.em/-/blob/da92d0e5799d018f311f1bf5cdd5a80e19e142ca/bob/learn/em/cpp/ISVTrainer.cpp#L68 # https://gitlab.idiap.ch/bob/bob.learn.em/-/blob/da92d0e5799d018f311f1bf5cdd5a80e19e142ca/bob/learn/em/cpp/ISVTrainer.cpp#L68
# 0th order stats # 0th order stats
...@@ -814,7 +820,7 @@ class FactorAnalysisBase(BaseEstimator): ...@@ -814,7 +820,7 @@ class FactorAnalysisBase(BaseEstimator):
np.zeros( np.zeros(
( (
self.r_U, self.r_U,
y.count(y_i), np.sum(y == y_i),
) )
) )
) )
...@@ -1192,7 +1198,7 @@ class ISVMachine(FactorAnalysisBase): ...@@ -1192,7 +1198,7 @@ class ISVMachine(FactorAnalysisBase):
ubm=None, ubm=None,
**gmm_kwargs, **gmm_kwargs,
): ):
super(ISVMachine, self).__init__( super().__init__(
r_U=r_U, r_U=r_U,
relevance_factor=relevance_factor, relevance_factor=relevance_factor,
em_iterations=em_iterations, em_iterations=em_iterations,
...@@ -1213,7 +1219,7 @@ class ISVMachine(FactorAnalysisBase): ...@@ -1213,7 +1219,7 @@ class ISVMachine(FactorAnalysisBase):
y: np.ndarray of shape(n_clients,) y: np.ndarray of shape(n_clients,)
Client labels. Client labels.
""" """
return super(ISVMachine, self).initialize(X, y) return super().initialize(X, y)
def e_step(self, X, y, n_acc, f_acc): def e_step(self, X, y, n_acc, f_acc):
""" """
...@@ -1252,7 +1258,7 @@ class ISVMachine(FactorAnalysisBase): ...@@ -1252,7 +1258,7 @@ class ISVMachine(FactorAnalysisBase):
self.update_U(acc_U_A1, acc_U_A2) self.update_U(acc_U_A1, acc_U_A2)
def fit(self, X, y): def fit_using_stats(self, X, y):
""" """
Trains the U matrix (session variability matrix) Trains the U matrix (session variability matrix)
...@@ -1270,10 +1276,10 @@ class ISVMachine(FactorAnalysisBase): ...@@ -1270,10 +1276,10 @@ class ISVMachine(FactorAnalysisBase):
""" """
y = np.array(y).tolist() if not isinstance(y, list) else y y = np.asarray(y)
# TODO: Point of MAP-REDUCE # TODO: Point of MAP-REDUCE
n_acc, f_acc = self.initialize(X, y) n_acc, f_acc = self.initialize_using_stats(X, y)
for i in range(self.em_iterations): for i in range(self.em_iterations):
logger.info("U Training: Iteration %d", i) logger.info("U Training: Iteration %d", i)
# TODO: Point of MAP-REDUCE # TODO: Point of MAP-REDUCE
...@@ -1412,20 +1418,25 @@ class JFAMachine(FactorAnalysisBase): ...@@ -1412,20 +1418,25 @@ class JFAMachine(FactorAnalysisBase):
""" """
def __init__( def __init__(
self, ubm, r_U, r_V, em_iterations=10, relevance_factor=4.0, seed=0 self,
ubm,
r_U,
r_V,
em_iterations=10,
relevance_factor=4.0,
seed=0,
**kwargs,
): ):
super(JFAMachine, self).__init__( super().__init__(
ubm, ubm=ubm,
r_U=r_U, r_U=r_U,
r_V=r_V, r_V=r_V,
relevance_factor=relevance_factor, relevance_factor=relevance_factor,
em_iterations=em_iterations, em_iterations=em_iterations,
seed=seed, seed=seed,
**kwargs,
) )
def initialize(self, X, y):
return super(JFAMachine, self).initialize(X, y)
def e_step_v(self, X, y, n_acc, f_acc): def e_step_v(self, X, y, n_acc, f_acc):
""" """
ISV E-step for the V matrix. ISV E-step for the V matrix.
...@@ -1769,7 +1780,7 @@ class JFAMachine(FactorAnalysisBase): ...@@ -1769,7 +1780,7 @@ class JFAMachine(FactorAnalysisBase):
""" """
return self.enroll([self.ubm.transform(X)], iterations) return self.enroll([self.ubm.transform(X)], iterations)
def fit(self, X, y): def fit_using_stats(self, X, y):
""" """
Trains the U matrix (session variability matrix) Trains the U matrix (session variability matrix)
...@@ -1795,10 +1806,10 @@ class JFAMachine(FactorAnalysisBase): ...@@ -1795,10 +1806,10 @@ class JFAMachine(FactorAnalysisBase):
): ):
self.create_UVD() self.create_UVD()
y = np.array(y).tolist() if not isinstance(y, list) else y y = np.asarray(y)
# TODO: Point of MAP-REDUCE # TODO: Point of MAP-REDUCE
n_acc, f_acc = self.initialize(X, y) n_acc, f_acc = self.initialize_using_stats(X, y)
# Updating V # Updating V
for i in range(self.em_iterations): for i in range(self.em_iterations):
......
...@@ -65,7 +65,7 @@ def test_ISVMachine(): ...@@ -65,7 +65,7 @@ def test_ISVMachine():
ubm.variances = np.array([[1, 2, 1], [2, 1, 2]], "float64") ubm.variances = np.array([[1, 2, 1], [2, 1, 2]], "float64")
# Creates a ISVMachine # Creates a ISVMachine
isv_machine = ISVMachine(ubm, r_U=2, em_iterations=10) isv_machine = ISVMachine(ubm=ubm, r_U=2, em_iterations=10)
isv_machine.U = np.array( isv_machine.U = np.array(
[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], "float64" [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], "float64"
) )
......
...@@ -126,7 +126,7 @@ def test_JFATrainAndEnrol(): ...@@ -126,7 +126,7 @@ def test_JFATrainAndEnrol():
it.U = copy.deepcopy(M_u) it.U = copy.deepcopy(M_u)
it.V = copy.deepcopy(M_v) it.V = copy.deepcopy(M_v)
it.D = copy.deepcopy(M_d) it.D = copy.deepcopy(M_d)
it.fit(TRAINING_STATS_X, TRAINING_STATS_y) it.fit_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
v_ref = np.array( v_ref = np.array(
[ [
...@@ -225,7 +225,7 @@ def test_JFATrainAndEnrolWithNumpy(): ...@@ -225,7 +225,7 @@ def test_JFATrainAndEnrolWithNumpy():
it.U = copy.deepcopy(M_u) it.U = copy.deepcopy(M_u)
it.V = copy.deepcopy(M_v) it.V = copy.deepcopy(M_v)
it.D = copy.deepcopy(M_d) it.D = copy.deepcopy(M_d)
it.fit(TRAINING_STATS_X, TRAINING_STATS_y) it.fit_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
v_ref = np.array( v_ref = np.array(
[ [
...@@ -337,14 +337,14 @@ def test_ISVTrainAndEnrol(): ...@@ -337,14 +337,14 @@ def test_ISVTrainAndEnrol():
ubm.variances = UBM_VAR.reshape((2, 3)) ubm.variances = UBM_VAR.reshape((2, 3))
it = ISVMachine( it = ISVMachine(
ubm, ubm=ubm,
r_U=2, r_U=2,
relevance_factor=4.0, relevance_factor=4.0,
em_iterations=10, em_iterations=10,
) )
it.U = copy.deepcopy(M_u) it.U = copy.deepcopy(M_u)
it = it.fit(TRAINING_STATS_X, TRAINING_STATS_y) it = it.fit_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
np.testing.assert_allclose(it.D, d_ref, rtol=eps, atol=1e-8) np.testing.assert_allclose(it.D, d_ref, rtol=eps, atol=1e-8)
np.testing.assert_allclose(it.U, u_ref, rtol=eps, atol=1e-8) np.testing.assert_allclose(it.U, u_ref, rtol=eps, atol=1e-8)
...@@ -417,14 +417,14 @@ def test_ISVTrainAndEnrolWithNumpy(): ...@@ -417,14 +417,14 @@ def test_ISVTrainAndEnrolWithNumpy():
ubm.variances = UBM_VAR.reshape((2, 3)) ubm.variances = UBM_VAR.reshape((2, 3))
it = ISVMachine( it = ISVMachine(
ubm, ubm=ubm,
r_U=2, r_U=2,
relevance_factor=4.0, relevance_factor=4.0,
em_iterations=10, em_iterations=10,
) )
it.U = copy.deepcopy(M_u) it.U = copy.deepcopy(M_u)
it = it.fit(TRAINING_STATS_X, TRAINING_STATS_y) it = it.fit_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
np.testing.assert_allclose(it.D, d_ref, rtol=eps, atol=1e-8) np.testing.assert_allclose(it.D, d_ref, rtol=eps, atol=1e-8)
np.testing.assert_allclose(it.U, u_ref, rtol=eps, atol=1e-8) np.testing.assert_allclose(it.U, u_ref, rtol=eps, atol=1e-8)
...@@ -466,13 +466,13 @@ def test_JFATrainInitialize(): ...@@ -466,13 +466,13 @@ def test_JFATrainInitialize():
it = JFAMachine(ubm, 2, 2, em_iterations=10) it = JFAMachine(ubm, 2, 2, em_iterations=10)
# first round # first round
it.initialize(TRAINING_STATS_X, TRAINING_STATS_y) it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
u1 = it.U u1 = it.U
v1 = it.V v1 = it.V
d1 = it.D d1 = it.D
# second round # second round
it.initialize(TRAINING_STATS_X, TRAINING_STATS_y) it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
u2 = it.U u2 = it.U
v2 = it.V v2 = it.V
d2 = it.D d2 = it.D
...@@ -493,15 +493,15 @@ def test_ISVTrainInitialize(): ...@@ -493,15 +493,15 @@ def test_ISVTrainInitialize():
ubm.variances = UBM_VAR.reshape((2, 3)) ubm.variances = UBM_VAR.reshape((2, 3))
# ISV # ISV
it = ISVMachine(ubm, 2, em_iterations=10) it = ISVMachine(2, em_iterations=10, ubm=ubm)
# it.rng = rng # it.rng = rng
it.initialize(TRAINING_STATS_X, TRAINING_STATS_y) it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
u1 = copy.deepcopy(it.U) u1 = copy.deepcopy(it.U)
d1 = copy.deepcopy(it.D) d1 = copy.deepcopy(it.D)
# second round # second round
it.initialize(TRAINING_STATS_X, TRAINING_STATS_y) it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
u2 = it.U u2 = it.U
d2 = it.D d2 = it.D
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment