diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index 1722053250b398edbef8baa88ca4995e9d95968d..ebb8c6ac3674a5e70a99d5953fb56235b3bb9b38 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -8,6 +8,7 @@ import dask
 import numpy as np
 
 from sklearn.base import BaseEstimator
+from sklearn.utils.multiclass import unique_labels
 
 from .gmm import GMMMachine
 from .linear_scoring import linear_scoring
@@ -180,7 +181,7 @@ class FactorAnalysisBase(BaseEstimator):
         Estimates the number of classes given the labels
         """
 
-        return np.max(y) + 1
+        return len(unique_labels(y))
 
     def initialize(self, X, y):
         """
@@ -1157,6 +1158,29 @@ class FactorAnalysisBase(BaseEstimator):
 
         return self.score_using_stats(model, self.ubm.transform(data))
 
+    def fit(self, X, y):
+        """
+        Trains the machine from raw samples
+
+        Every element of `X` is passed through `self.ubm.transform` and the
+        resulting list of statistics is handed to `fit_using_stats`.
+
+        Parameters
+        ----------
+        X
+            Iterable of samples; each element is given as-is to
+            `self.ubm.transform`
+        y
+            Labels, forwarded unchanged to `fit_using_stats`
+
+        Returns
+        -------
+        The return value of `fit_using_stats`
+        """
+
+        stats = [self.ubm.transform(xx) for xx in X]
+        return self.fit_using_stats(stats, y)
+
 
 class ISVMachine(FactorAnalysisBase):
     """
diff --git a/bob/learn/em/test/test_factor_analysis.py b/bob/learn/em/test/test_factor_analysis.py
index 3e03e9531c4988533ae2a8e580ba9b63460798b4..d1b214162d8e794ede2b59a8d32c4a0d363d1d61 100644
--- a/bob/learn/em/test/test_factor_analysis.py
+++ b/bob/learn/em/test/test_factor_analysis.py
@@ -418,7 +418,7 @@ def test_ISVMachine():
     isv_machine.U = np.array(
         [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], "float64"
     )
-    # base.v = numpy.array([[0], [0], [0], [0], [0], [0]], 'float64')
+    # base.v = np.array([[0], [0], [0], [0], [0], [0]], 'float64')
     isv_machine.D = np.array([0, 1, 0, 1, 0, 1], "float64")
 
     # Defines GMMStats
@@ -441,3 +441,48 @@ def test_ISVMachine():
     score_ref = -1.2343813195374242
     score = isv_machine.score(latent_z, X)
     np.testing.assert_allclose(score, score_ref, atol=eps)
+
+
+def test_ISV_fit():
+    """Trains an ISVMachine from raw samples and checks the learned U."""
+    # The reference values asserted below depend on this exact seed
+    np.random.seed(10)
+    data_class1 = np.random.normal(0, 0.5, (10, 3))
+    data_class2 = np.random.normal(-0.2, 0.2, (10, 3))
+    data = np.concatenate([data_class1, data_class2], axis=0)
+    labels = [0] * 10 + [1] * 10
+
+    # Creating a fake prior with 2 gaussians
+    prior_gmm = GMMMachine(2)
+    prior_gmm.means = np.vstack(
+        (np.random.normal(0, 0.5, (1, 3)), np.random.normal(1, 0.5, (1, 3)))
+    )
+
+    # All nice and round diagonal covariance
+    prior_gmm.variances = np.ones((2, 3)) * 0.5
+    prior_gmm.weights = np.array([0.3, 0.7])
+
+    # Finally doing the ISV training
+    isv = ISVMachine(
+        2,
+        ubm=prior_gmm,
+        relevance_factor=4,
+        em_iterations=50,
+    )
+    isv.fit(data, labels)
+
+    # Checking the session offset w.r.t each Gaussian component.
+    # The reference is rounded to 3 decimals, so an absolute tolerance is
+    # required (assert_allclose defaults to rtol=1e-7 and atol=0).
+    np.testing.assert_allclose(
+        isv.U,
+        [
+            [-0.01, -0.027],
+            [-0.002, -0.004],
+            [0.028, 0.074],
+            [0.012, 0.03],
+            [0.033, 0.085],
+            [0.046, 0.12],
+        ],
+        atol=1e-3,
+    )