diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index 820a1ee786f42a9cf66fb7419ea43343edf4c887..1d0b81edff2f27fdfa0e74ec2d7c05913b6c2896 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -362,10 +362,9 @@ class FactorAnalysisBase(BaseEstimator):
         E[y_i] for class `i`

         """
-        f_i = x_i.sum_px
         n_i = x_i.n
-        n_ic = np.repeat(n_i, self.supervector_dimension // 2)
+        n_ic = np.repeat(n_i, self.feature_dimension)
         V = self._V

         ## N_ih*( m + D*z)
@@ -644,7 +643,7 @@ class FactorAnalysisBase(BaseEstimator):

         """

-        tmp_CD = np.repeat(n_acc_i, self.supervector_dimension // 2)
+        tmp_CD = np.repeat(n_acc_i, self.feature_dimension)
         id_plus_d_prod = np.ones(tmp_CD.shape) + dt_inv_sigma_d * tmp_CD
         return 1 / id_plus_d_prod

@@ -664,7 +663,7 @@ class FactorAnalysisBase(BaseEstimator):

         m = self.mean_supervector

-        tmp_CD = np.repeat(n_acc_i, self.supervector_dimension // 2)
+        tmp_CD = np.repeat(n_acc_i, self.feature_dimension)

         ## JFA session part
         V_dot_v = V @ latent_y_i if latent_y_i is not None else 0
@@ -675,7 +674,7 @@ class FactorAnalysisBase(BaseEstimator):
         # Looping over the sessions
         for session_id in range(len(X_i)):
             n_i = X_i[session_id].n
-            tmp_CD = np.repeat(n_i, self.supervector_dimension // 2)
+            tmp_CD = np.repeat(n_i, self.feature_dimension)
             x_i_h = latent_x_i[:, session_id]

             fn_z_i -= tmp_CD * (U @ x_i_h)
@@ -757,7 +756,7 @@ class FactorAnalysisBase(BaseEstimator):
                 X_i, latent_x_i, latent_y_i, n_acc[y_i], f_acc[y_i]
             )

-            tmp_CD = np.repeat(n_acc[y_i], self.supervector_dimension // 2)
+            tmp_CD = np.repeat(n_acc[y_i], self.feature_dimension)
             acc_D_A1 += (
                 id_plus_d_prod + latent_z[y_i] * latent_z[y_i]
             ) * tmp_CD
@@ -1024,7 +1023,7 @@ class FactorAnalysisBase(BaseEstimator):
         # y = self.y[i]  # Not doing JFA

-        tmp_CD = np.repeat(n_acc_i, self.supervector_dimension // 2)
+        tmp_CD = np.repeat(n_acc_i, self.feature_dimension)

         fn_y_i = f_acc_i.flatten() - tmp_CD * (
             m - D * latent_z_i
@@ -1036,7 +1035,7 @@ class FactorAnalysisBase(BaseEstimator):
         for session_id in range(len(X_i)):
             n_i = X_i[session_id].n
             U_dot_x = U @ latent_x_i[:, session_id]
-            tmp_CD = np.repeat(n_i, self.supervector_dimension // 2)
+            tmp_CD = np.repeat(n_i, self.feature_dimension)
             fn_y_i -= tmp_CD * U_dot_x

         return fn_y_i
@@ -1130,7 +1129,9 @@ class ISVMachine(FactorAnalysisBase):

     """

-    def __init__(self, ubm, r_U, em_iterations=10, relevance_factor=4.0, seed=0):
+    def __init__(
+        self, ubm, r_U, em_iterations=10, relevance_factor=4.0, seed=0
+    ):
         super(ISVMachine, self).__init__(
             ubm,
             r_U=r_U,
@@ -1201,6 +1202,8 @@ class ISVMachine(FactorAnalysisBase):
         if not hasattr(self, "_U") or not hasattr(self, "_D"):
             self.create_UVD()

+        y = y.tolist() if not isinstance(y, list) else y
+
         # TODO: Point of parallelism
         n_acc, f_acc = self.initialize(X, y)
         for i in range(self.em_iterations):
@@ -1675,6 +1678,8 @@ class JFAMachine(FactorAnalysisBase):
         ):
             self.create_UVD()

+        y = y.tolist() if not isinstance(y, list) else y
+
         # TODO: Point of parallelism
         n_acc, f_acc = self.initialize(X, y)
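Why the hunks above use feature_dimension: a GMM supervector stacks one block of statistics per Gaussian, so its length is n_gaussians * feature_dimension, and expanding the per-Gaussian occupancy counts n_i to supervector shape needs exactly feature_dimension copies per Gaussian. The old supervector_dimension // 2 only coincided with that when the features happened to be two-dimensional (as in the iris examples below). The added y.tolist() guard lets fit accept labels as either a list or a numpy array; the labels later index per-class accumulators such as n_acc[y_i]. A minimal numpy sketch of the shape argument (all names and values hypothetical):

    import numpy as np

    n_gaussians, feature_dimension = 3, 2
    supervector_dimension = n_gaussians * feature_dimension  # 6

    # Per-Gaussian occupancy counts, expanded to the supervector layout.
    n_i = np.array([10.0, 5.0, 7.0])
    n_ic = np.repeat(n_i, feature_dimension)  # [10. 10.  5.  5.  7.  7.]
    assert n_ic.shape == (supervector_dimension,)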
diff --git a/doc/plot/plot_ISV.py b/doc/plot/plot_ISV.py
new file mode 100644
index 0000000000000000000000000000000000000000..91b9ff61fdee5cc709c3797f3d5d753ddaf6780d
--- /dev/null
+++ b/doc/plot/plot_ISV.py
@@ -0,0 +1,141 @@
+from sklearn.datasets import load_iris
+
+import bob.learn.em
+import matplotlib.pyplot as plt
+import numpy as np
+
+np.random.seed(2)  # FIXING A SEED
+
+
+# Legacy helper kept for reference; the sklearn-style API below supersedes it.
+def isv_train(features, ubm):
+    """
+    Train the U matrix
+
+    **Parameters**
+    features: List of :py:class:`bob.learn.em.GMMStats` organized by class
+
+    ubm: A trained UBM (:py:class:`bob.learn.em.GMMMachine`)
+
+    """
+
+    stats = []
+    for user in features:
+        user_stats = []
+        for f in user:
+            s = bob.learn.em.GMMStats(ubm.shape[0], ubm.shape[1])
+            ubm.acc_statistics(f, s)
+            user_stats.append(s)
+        stats.append(user_stats)
+
+    relevance_factor = 4
+    subspace_dimension_of_u = 1
+
+    isvbase = bob.learn.em.ISVBase(ubm, subspace_dimension_of_u)
+    trainer = bob.learn.em.ISVTrainer(relevance_factor)
+    # trainer.rng = bob.core.random.mt19937(int(self.init_seed))
+    bob.learn.em.train(trainer, isvbase, stats, max_iterations=50)
+
+    return isvbase
+
+
+# GENERATING DATA
+iris_data = load_iris()
+# Keep two features (sepal length and petal width) so the 2-D scatter,
+# the axis labels and the per-Gaussian U-block slicing below stay consistent.
+X = iris_data.data[:, (0, 3)]
+y = iris_data.target
+
+setosa = X[iris_data.target == 0]
+versicolor = X[iris_data.target == 1]
+virginica = X[iris_data.target == 2]
+
+n_gaussians = 3
+r_U = 1
+
+
+# TRAINING THE PRIOR
+ubm = bob.learn.em.GMMMachine(n_gaussians).fit(X)
+
+gmm_stats = [ubm.acc_statistics(x[np.newaxis]) for x in X]
+isv_machine = bob.learn.em.ISVMachine(ubm, r_U).fit(gmm_stats, y)
+
+# isvbase = isv_train([setosa, versicolor, virginica], ubm)
+
+# Variability directions (one 2-D block of U per Gaussian)
+u0 = isv_machine.U[0:2, 0] / np.linalg.norm(isv_machine.U[0:2, 0])
+u1 = isv_machine.U[2:4, 0] / np.linalg.norm(isv_machine.U[2:4, 0])
+u2 = isv_machine.U[4:6, 0] / np.linalg.norm(isv_machine.U[4:6, 0])
+
+figure, ax = plt.subplots()
+plt.scatter(setosa[:, 0], setosa[:, 1], c="darkcyan", label="setosa")
+plt.scatter(
+    versicolor[:, 0], versicolor[:, 1], c="goldenrod", label="versicolor"
+)
+plt.scatter(virginica[:, 0], virginica[:, 1], c="dimgrey", label="virginica")
+
+plt.scatter(
+    ubm.means[:, 0],
+    ubm.means[:, 1],
+    c="blue",
+    marker="x",
+    label="centroids - mle",
+)
+# plt.scatter(ubm.means[:, 0], ubm.means[:, 1], c="blue",
+#             marker=".", label="within class variability", s=0.01)
+
+ax.arrow(
+    ubm.means[0, 0],
+    ubm.means[0, 1],
+    u0[0],
+    u0[1],
+    fc="k",
+    ec="k",
+    head_width=0.05,
+    head_length=0.1,
+)
+ax.arrow(
+    ubm.means[1, 0],
+    ubm.means[1, 1],
+    u1[0],
+    u1[1],
+    fc="k",
+    ec="k",
+    head_width=0.05,
+    head_length=0.1,
+)
+ax.arrow(
+    ubm.means[2, 0],
+    ubm.means[2, 1],
+    u2[0],
+    u2[1],
+    fc="k",
+    ec="k",
+    head_width=0.05,
+    head_length=0.1,
+)
+plt.text(
+    ubm.means[0, 0] + u0[0],
+    ubm.means[0, 1] + u0[1] - 0.1,
+    r"$\mathbf{U}_1$",
+    fontsize=15,
+)
+plt.text(
+    ubm.means[1, 0] + u1[0],
+    ubm.means[1, 1] + u1[1] - 0.1,
+    r"$\mathbf{U}_2$",
+    fontsize=15,
+)
+plt.text(
+    ubm.means[2, 0] + u2[0],
+    ubm.means[2, 1] + u2[1] - 0.1,
+    r"$\mathbf{U}_3$",
+    fontsize=15,
+)
+
+plt.xticks([], [])
+plt.yticks([], [])
+
+# plt.grid(True)
+plt.xlabel("Sepal length")
+plt.ylabel("Petal width")
+plt.legend()
+plt.tight_layout()
+plt.show()
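In plot_ISV.py above, each arrow is the within-class variability direction learned for one Gaussian: U has shape (n_gaussians * feature_dimension, r_U), so with two-dimensional features Gaussian c occupies rows 2c through 2c+1. A hypothetical helper (not part of the library) expressing the same slicing generically:

    def gaussian_block(M, c, feature_dimension=2):
        # Rows of a supervector-shaped array that belong to Gaussian `c`.
        return M[c * feature_dimension : (c + 1) * feature_dimension]

    # e.g. the arrow u1 above is gaussian_block(isv_machine.U, 1)[:, 0],
    # normalized to unit length for display.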
diff --git a/doc/plot/plot_JFA.py b/doc/plot/plot_JFA.py
new file mode 100644
index 0000000000000000000000000000000000000000..f77e6779c9f6f9b52ac78e72d76b498b8ec61d92
--- /dev/null
+++ b/doc/plot/plot_JFA.py
@@ -0,0 +1,201 @@
+from sklearn.datasets import load_iris
+
+import bob.learn.em
+import matplotlib.pyplot as plt
+import numpy as np
+
+np.random.seed(2)  # FIXING A SEED
+
+
+# Legacy helper kept for reference; the sklearn-style API below supersedes it.
+def isv_train(features, ubm):
+    """
+    Train the U matrix
+
+    **Parameters**
+    features: List of :py:class:`bob.learn.em.GMMStats` organized by class
+
+    ubm: A trained UBM (:py:class:`bob.learn.em.GMMMachine`)
+
+    """
+
+    stats = []
+    for user in features:
+        user_stats = []
+        for f in user:
+            s = bob.learn.em.GMMStats(ubm.shape[0], ubm.shape[1])
+            ubm.acc_statistics(f, s)
+            user_stats.append(s)
+        stats.append(user_stats)
+
+    relevance_factor = 4
+    subspace_dimension_of_u = 1
+
+    isvbase = bob.learn.em.ISVBase(ubm, subspace_dimension_of_u)
+    trainer = bob.learn.em.ISVTrainer(relevance_factor)
+    # trainer.rng = bob.core.random.mt19937(int(self.init_seed))
+    bob.learn.em.train(trainer, isvbase, stats, max_iterations=50)
+
+    return isvbase
+
+
+# GENERATING DATA
+iris_data = load_iris()
+# Keep two features (sepal length and petal width) so the 2-D scatter,
+# the axis labels and the per-Gaussian U/V-block slicing below stay consistent.
+X = iris_data.data[:, (0, 3)]
+y = iris_data.target
+
+setosa = X[iris_data.target == 0]
+versicolor = X[iris_data.target == 1]
+virginica = X[iris_data.target == 2]
+
+n_gaussians = 3
+r_U = 1
+r_V = 1
+
+
+# TRAINING THE PRIOR
+ubm = bob.learn.em.GMMMachine(n_gaussians).fit(X)
+
+gmm_stats = [ubm.acc_statistics(x[np.newaxis]) for x in X]
+jfa_machine = bob.learn.em.JFAMachine(ubm, r_U, r_V).fit(gmm_stats, y)
+
+
+# Within-class (session) variability directions U, one 2-D block per Gaussian
+u0 = jfa_machine.U[0:2, 0] / np.linalg.norm(jfa_machine.U[0:2, 0])
+u1 = jfa_machine.U[2:4, 0] / np.linalg.norm(jfa_machine.U[2:4, 0])
+u2 = jfa_machine.U[4:6, 0] / np.linalg.norm(jfa_machine.U[4:6, 0])
+
+
+# Between-class variability directions V, one 2-D block per Gaussian
+v0 = jfa_machine.V[0:2, 0] / np.linalg.norm(jfa_machine.V[0:2, 0])
+v1 = jfa_machine.V[2:4, 0] / np.linalg.norm(jfa_machine.V[2:4, 0])
+v2 = jfa_machine.V[4:6, 0] / np.linalg.norm(jfa_machine.V[4:6, 0])
+
+
+figure, ax = plt.subplots()
+plt.scatter(setosa[:, 0], setosa[:, 1], c="darkcyan", label="setosa")
+plt.scatter(
+    versicolor[:, 0], versicolor[:, 1], c="goldenrod", label="versicolor"
+)
+plt.scatter(virginica[:, 0], virginica[:, 1], c="dimgrey", label="virginica")
+
+plt.scatter(
+    ubm.means[:, 0],
+    ubm.means[:, 1],
+    c="blue",
+    marker="x",
+    label="centroids - mle",
+)
+# plt.scatter(ubm.means[:, 0], ubm.means[:, 1], c="blue",
+#             marker=".", label="within class variability", s=0.01)
+
+# U
+ax.arrow(
+    ubm.means[0, 0],
+    ubm.means[0, 1],
+    u0[0],
+    u0[1],
+    fc="k",
+    ec="k",
+    head_width=0.05,
+    head_length=0.1,
+)
+ax.arrow(
+    ubm.means[1, 0],
+    ubm.means[1, 1],
+    u1[0],
+    u1[1],
+    fc="k",
+    ec="k",
+    head_width=0.05,
+    head_length=0.1,
+)
+ax.arrow(
+    ubm.means[2, 0],
+    ubm.means[2, 1],
+    u2[0],
+    u2[1],
+    fc="k",
+    ec="k",
+    head_width=0.05,
+    head_length=0.1,
+)
+plt.text(
+    ubm.means[0, 0] + u0[0],
+    ubm.means[0, 1] + u0[1] - 0.1,
+    r"$\mathbf{U}_1$",
+    fontsize=15,
+)
+plt.text(
+    ubm.means[1, 0] + u1[0],
+    ubm.means[1, 1] + u1[1] - 0.1,
+    r"$\mathbf{U}_2$",
+    fontsize=15,
+)
+plt.text(
+    ubm.means[2, 0] + u2[0],
+    ubm.means[2, 1] + u2[1] - 0.1,
+    r"$\mathbf{U}_3$",
+    fontsize=15,
+)
+
+# V
+ax.arrow(
+    ubm.means[0, 0],
+    ubm.means[0, 1],
+    v0[0],
+    v0[1],
+    fc="k",
+    ec="k",
+    head_width=0.05,
+    head_length=0.1,
+)
+ax.arrow(
+    ubm.means[1, 0],
+    ubm.means[1, 1],
+    v1[0],
+    v1[1],
+    fc="k",
+    ec="k",
+    head_width=0.05,
+    head_length=0.1,
+)
+ax.arrow(
+    ubm.means[2, 0],
+    ubm.means[2, 1],
+    v2[0],
+    v2[1],
+    fc="k",
+    ec="k",
+    head_width=0.05,
+    head_length=0.1,
+)
+plt.text(
+    ubm.means[0, 0] + v0[0],
+    ubm.means[0, 1] + v0[1] - 0.1,
+    r"$\mathbf{V}_1$",
+    fontsize=15,
+)
+plt.text(
+    ubm.means[1, 0] + v1[0],
+    ubm.means[1, 1] + v1[1] - 0.1,
+    r"$\mathbf{V}_2$",
+    fontsize=15,
+)
+plt.text(
+    ubm.means[2, 0] + v2[0],
+    ubm.means[2, 1] + v2[1] - 0.1,
+    r"$\mathbf{V}_3$",
+    fontsize=15,
+)
+
+plt.xticks([], [])
+plt.yticks([], [])
+
+# plt.grid(True)
+plt.xlabel("Sepal length")
+plt.ylabel("Petal width")
+plt.legend(loc=2)
+# plt.ylim([-1, 3.5])
+
+plt.tight_layout()
+plt.show()
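For orientation, plot_JFA.py draws both subspaces that the factor_analysis.py hunks manipulate: JFA models a class supervector as m + V y + D z, with an additional per-session offset U x (compare V @ latent_y_i, U @ x_i_h and D * latent_z_i above), while ISV drops the V term. A shapes-only numpy sketch of that synthesis, all values hypothetical:

    import numpy as np

    C, F, r_U, r_V = 3, 2, 1, 1         # Gaussians, feature dim, subspace ranks
    CF = C * F                          # supervector dimension
    m = np.zeros(CF)                    # UBM mean supervector
    U = np.ones((CF, r_U))              # within-class (session) subspace
    V = np.ones((CF, r_V))              # between-class subspace
    D = np.ones(CF)                     # diagonal residual term
    x, y_lat, z = np.ones(r_U), np.ones(r_V), np.ones(CF)

    s = m + V @ y_lat + D * z + U @ x   # one session of one class
    assert s.shape == (CF,)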