From 08ab75d7b3e95335eaa89e0ccf1ac18edd47cad6 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Mon, 11 Apr 2022 17:05:17 +0200
Subject: [PATCH] [factor_analysis] add a fit method using data not stats

---
 bob/learn/em/factor_analysis.py           |  7 +++-
 bob/learn/em/test/test_factor_analysis.py | 42 ++++++++++++++++++++++-
 2 files changed, 47 insertions(+), 2 deletions(-)

diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index 1722053..ebb8c6a 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -8,6 +8,7 @@ import dask
 import numpy as np
 
 from sklearn.base import BaseEstimator
+from sklearn.utils.multiclass import unique_labels
 
 from .gmm import GMMMachine
 from .linear_scoring import linear_scoring
@@ -180,7 +181,7 @@ class FactorAnalysisBase(BaseEstimator):
         Estimates the number of classes given the labels
         """
 
-        return np.max(y) + 1
+        return len(unique_labels(y))
 
     def initialize(self, X, y):
         """
@@ -1157,6 +1158,10 @@ class FactorAnalysisBase(BaseEstimator):
 
         return self.score_using_stats(model, self.ubm.transform(data))
 
+    def fit(self, X, y):
+        stats = [self.ubm.transform(xx) for xx in X]
+        return self.fit_using_stats(stats, y)
+
 
 class ISVMachine(FactorAnalysisBase):
     """
diff --git a/bob/learn/em/test/test_factor_analysis.py b/bob/learn/em/test/test_factor_analysis.py
index 3e03e95..d1b2141 100644
--- a/bob/learn/em/test/test_factor_analysis.py
+++ b/bob/learn/em/test/test_factor_analysis.py
@@ -418,7 +418,7 @@ def test_ISVMachine():
     isv_machine.U = np.array(
         [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], "float64"
     )
-    # base.v = numpy.array([[0], [0], [0], [0], [0], [0]], 'float64')
+    # base.v = np.array([[0], [0], [0], [0], [0], [0]], 'float64')
     isv_machine.D = np.array([0, 1, 0, 1, 0, 1], "float64")
 
     # Defines GMMStats
@@ -441,3 +441,43 @@ def test_ISVMachine():
     score_ref = -1.2343813195374242
     score = isv_machine.score(latent_z, X)
     np.testing.assert_allclose(score, score_ref, atol=eps)
+
+
+def test_ISV_fit():
+    np.random.seed(10)
+    data_class1 = np.random.normal(0, 0.5, (10, 3))
+    data_class2 = np.random.normal(-0.2, 0.2, (10, 3))
+    data = np.concatenate([data_class1, data_class2], axis=0)
+    labels = [0] * 10 + [1] * 10
+
+    # Creating a fake prior with 2 gaussians
+    prior_gmm = GMMMachine(2)
+    prior_gmm.means = np.vstack(
+        (np.random.normal(0, 0.5, (1, 3)), np.random.normal(1, 0.5, (1, 3)))
+    )
+
+    # All nice and round diagonal covariance
+    prior_gmm.variances = np.ones((2, 3)) * 0.5
+    prior_gmm.weights = np.array([0.3, 0.7])
+
+    # Finally doing the ISV training
+    isv = ISVMachine(
+        2,
+        ubm=prior_gmm,
+        relevance_factor=4,
+        em_iterations=50,
+    )
+    isv.fit(data, labels)
+
+    # Printing the session offset w.r.t each Gaussian component
+    np.testing.assert_allclose(
+        isv.U,
+        [
+            [-0.01, -0.027],
+            [-0.002, -0.004],
+            [0.028, 0.074],
+            [0.012, 0.03],
+            [0.033, 0.085],
+            [0.046, 0.12],
+        ],
+    )
-- 
GitLab