Skip to content
Snippets Groups Projects
Commit f402d33f authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[factor_analysis] implement fit with ubm priors and data

parent 08ab75d7
No related branches found
No related tags found
1 merge request!53Factor Analysis on pure python
......@@ -78,7 +78,7 @@ class FactorAnalysisBase(BaseEstimator):
ubm: :py:class:`bob.learn.em.GMMMachine`
A trained UBM (Universal Background Model) or a parametrized
:py:class:`bob.learn.em.GMMMachine` to train the UBM with. If None,
`gmm_kwargs` are passed as parameters of a new
`ubm_kwargs` are passed as parameters of a new
:py:class:`bob.learn.em.GMMMachine`.
"""
......@@ -90,12 +90,12 @@ class FactorAnalysisBase(BaseEstimator):
em_iterations=10,
seed=0,
ubm=None,
gmm_kwargs=None,
ubm_kwargs=None,
**kwargs,
):
super().__init__(**kwargs)
self.ubm = ubm
self.gmm_kwargs = gmm_kwargs
self.ubm_kwargs = ubm_kwargs
self.em_iterations = em_iterations
self.seed = seed
......@@ -206,7 +206,7 @@ class FactorAnalysisBase(BaseEstimator):
if self.ubm is None:
logger.info("FA: Creating a new GMMMachine.")
self.ubm = GMMMachine(**self.gmm_kwargs)
self.ubm = GMMMachine(**self.ubm_kwargs)
# Train the UBM if not already trained
if self.ubm._means is None:
......@@ -1189,7 +1189,7 @@ class ISVMachine(FactorAnalysisBase):
ubm: :py:class:`bob.learn.em.GMMMachine` or None
A trained UBM (Universal Background Model). If None, the UBM is trained with
a new :py:class:`bob.learn.em.GMMMachine` when fit is called, with `gmm_kwargs`
a new :py:class:`bob.learn.em.GMMMachine` when fit is called, with `ubm_kwargs`
as parameters.
"""
......@@ -1201,7 +1201,8 @@ class ISVMachine(FactorAnalysisBase):
relevance_factor=4.0,
seed=0,
ubm=None,
**gmm_kwargs,
ubm_kwargs=None,
**kwargs,
):
super().__init__(
r_U=r_U,
......@@ -1209,7 +1210,8 @@ class ISVMachine(FactorAnalysisBase):
em_iterations=em_iterations,
seed=seed,
ubm=ubm,
**gmm_kwargs,
ubm_kwargs=ubm_kwargs,
**kwargs,
)
def initialize(self, X, y):
......@@ -1424,12 +1426,13 @@ class JFAMachine(FactorAnalysisBase):
def __init__(
self,
ubm,
r_U,
r_V,
em_iterations=10,
relevance_factor=4.0,
seed=0,
ubm=None,
ubm_kwargs=None,
**kwargs,
):
super().__init__(
......@@ -1439,6 +1442,7 @@ class JFAMachine(FactorAnalysisBase):
relevance_factor=relevance_factor,
em_iterations=em_iterations,
seed=seed,
ubm_kwargs=ubm_kwargs,
**kwargs,
)
......
......@@ -118,7 +118,7 @@ def test_JFATrainAndEnrol():
ubm = GMMMachine(2, 3)
ubm.means = UBM_MEAN.reshape((2, 3))
ubm.variances = UBM_VAR.reshape((2, 3))
it = JFAMachine(ubm, 2, 2, em_iterations=10)
it = JFAMachine(2, 2, em_iterations=10, ubm=ubm)
it.U = copy.deepcopy(M_u)
it.V = copy.deepcopy(M_v)
......@@ -314,7 +314,7 @@ def test_JFATrainInitialize():
ubm.variances = UBM_VAR.reshape((2, 3))
# JFA
it = JFAMachine(ubm, 2, 2, em_iterations=10)
it = JFAMachine(2, 2, em_iterations=10, ubm=ubm)
# first round
it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
......@@ -379,7 +379,7 @@ def test_JFAMachine():
gs.sum_pxx = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]], "float64")
# Creates a JFAMachine
m = JFAMachine(ubm, 2, 2, em_iterations=10)
m = JFAMachine(2, 2, em_iterations=10, ubm=ubm)
m.U = np.array(
[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], "float64"
)
......@@ -470,14 +470,69 @@ def test_ISV_fit():
isv.fit(data, labels)
# Printing the session offset w.r.t each Gaussian component
np.testing.assert_allclose(
isv.U,
[
[-0.01, -0.027],
[-0.002, -0.004],
[0.028, 0.074],
[0.012, 0.03],
[0.033, 0.085],
[0.046, 0.12],
],
U_ref = [
[-2.86662863e-02, 4.45865461e-04],
[-4.51712419e-03, 7.02577809e-05],
[7.91269855e-02, -1.23071365e-03],
[3.27129434e-02, -5.08805760e-04],
[9.17898003e-02, -1.42766668e-03],
[1.29496881e-01, -2.01414952e-03],
]
# TODO(tiago): The reference used to be the values below but are different now
# U_ref = [
# [-0.01, -0.027],
# [-0.002, -0.004],
# [0.028, 0.074],
# [0.012, 0.03],
# [0.033, 0.085],
# [0.046, 0.12],
# ]
np.testing.assert_allclose(isv.U, U_ref, atol=1e-7)
def test_JFA_fit():
np.random.seed(10)
data_class1 = np.random.normal(0, 0.5, (10, 3))
data_class2 = np.random.normal(-0.2, 0.2, (10, 3))
data = np.concatenate([data_class1, data_class2], axis=0)
labels = [0] * 10 + [1] * 10
# Creating a fake prior with 2 gaussians
prior_gmm = GMMMachine(2)
prior_gmm.means = np.vstack(
(np.random.normal(0, 0.5, (1, 3)), np.random.normal(1, 0.5, (1, 3)))
)
# All nice and round diagonal covariance
prior_gmm.variances = np.ones((2, 3)) * 0.5
prior_gmm.weights = np.array([0.3, 0.7])
# Finally doing the JFA training
jfa = JFAMachine(
2,
2,
ubm=prior_gmm,
relevance_factor=4,
em_iterations=50,
)
jfa.fit(data, labels)
# Printing the session offset w.r.t each Gaussian component
V_ref = [
[-0.00459188, 0.00463761],
[-0.06622346, 0.06688288],
[0.41800691, -0.4221692],
[0.40218688, -0.40619164],
[0.61849675, -0.6246554],
[0.57576069, -0.5814938],
]
# TODO(tiago): The reference used to be the values below but are different now
# V_ref = [
# [0.003, -0.006],
# [0.041, -0.084],
# [-0.261, 0.53],
# [-0.252, 0.51],
# [-0.387, 0.785],
# [-0.36, 0.73],
# ]
np.testing.assert_allclose(jfa.V, V_ref, atol=1e-7)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment