From 90d7585c12c0ef45b104be2573ae8bda0c267786 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Fri, 22 Apr 2022 13:58:25 +0200
Subject: [PATCH] [doc] fix JFA plot script

---
 doc/plot/plot_JFA.py | 43 ++-----------------------------------------
 1 file changed, 2 insertions(+), 41 deletions(-)

diff --git a/doc/plot/plot_JFA.py b/doc/plot/plot_JFA.py
index 6cfa4bc..a56ab32 100644
--- a/doc/plot/plot_JFA.py
+++ b/doc/plot/plot_JFA.py
@@ -8,37 +8,6 @@ import bob.learn.em
 np.random.seed(2)  # FIXING A SEED
 
 
-def isv_train(features, ubm):
-    """
-    Train U matrix
-
-    **Parameters**
-      features: List of :py:class:`bob.learn.em.GMMStats` organized by class
-
-      n_gaussians: UBM (:py:class:`bob.learn.em.GMMMachine`)
-
-    """
-
-    stats = []
-    for user in features:
-        user_stats = []
-        for f in user:
-            s = bob.learn.em.GMMStats(ubm.shape[0], ubm.shape[1])
-            ubm.transform(f, s)
-            user_stats.append(s)
-        stats.append(user_stats)
-
-    relevance_factor = 4
-    subspace_dimension_of_u = 1
-
-    isvbase = bob.learn.em.ISVBase(ubm, subspace_dimension_of_u)
-    trainer = bob.learn.em.ISVTrainer(relevance_factor)
-    # trainer.rng = bob.core.random.mt19937(int(self.init_seed))
-    bob.learn.em.train(trainer, isvbase, stats, max_iterations=50)
-
-    return isvbase
-
-
 # GENERATING DATA
 iris_data = load_iris()
 X = np.column_stack((iris_data.data[:, 0], iris_data.data[:, 3]))
@@ -73,10 +42,8 @@ ubm.variances = np.array(
 )
 
 ubm.weights = np.array([0.36, 0.36, 0.28])
-# .fit(X)
 
-gmm_stats = [ubm.transform(x[np.newaxis]) for x in X]
-jfa_machine = bob.learn.em.JFAMachine(ubm, r_U, r_V, em_iterations=50)
+jfa_machine = bob.learn.em.JFAMachine(r_U, r_V, ubm=ubm, em_iterations=50)
 
 # Initializing with old bob initialization
 jfa_machine.U = np.array(
@@ -89,13 +56,7 @@ jfa_machine.Y = np.array(
 jfa_machine.D = np.array(
     [0.732467, 0.281321, 0.543212, -0.512974, 1.04108, 0.835224]
 )
-jfa_machine = jfa_machine.fit(gmm_stats, y)
-
-
-# .fit(gmm_stats, y)
-
-# gmm_stats = [ubm.transform(x) for x in [setosa, versicolor, virginica]]
-# jfa_machine = bob.learn.em.JFAMachine(ubm, r_U, r_V).fit(gmm_stats, [0, 1, 2])
+jfa_machine = jfa_machine.fit(X, y)
 
 
 # Variability direction U
-- 
GitLab