From 90d7585c12c0ef45b104be2573ae8bda0c267786 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Fri, 22 Apr 2022 13:58:25 +0200 Subject: [PATCH] [doc] fix JFA plot script --- doc/plot/plot_JFA.py | 43 ++----------------------------------------- 1 file changed, 2 insertions(+), 41 deletions(-) diff --git a/doc/plot/plot_JFA.py b/doc/plot/plot_JFA.py index 6cfa4bc..a56ab32 100644 --- a/doc/plot/plot_JFA.py +++ b/doc/plot/plot_JFA.py @@ -8,37 +8,6 @@ import bob.learn.em np.random.seed(2) # FIXING A SEED -def isv_train(features, ubm): - """ - Train U matrix - - **Parameters** - features: List of :py:class:`bob.learn.em.GMMStats` organized by class - - n_gaussians: UBM (:py:class:`bob.learn.em.GMMMachine`) - - """ - - stats = [] - for user in features: - user_stats = [] - for f in user: - s = bob.learn.em.GMMStats(ubm.shape[0], ubm.shape[1]) - ubm.transform(f, s) - user_stats.append(s) - stats.append(user_stats) - - relevance_factor = 4 - subspace_dimension_of_u = 1 - - isvbase = bob.learn.em.ISVBase(ubm, subspace_dimension_of_u) - trainer = bob.learn.em.ISVTrainer(relevance_factor) - # trainer.rng = bob.core.random.mt19937(int(self.init_seed)) - bob.learn.em.train(trainer, isvbase, stats, max_iterations=50) - - return isvbase - - # GENERATING DATA iris_data = load_iris() X = np.column_stack((iris_data.data[:, 0], iris_data.data[:, 3])) @@ -73,10 +42,8 @@ ubm.variances = np.array( ) ubm.weights = np.array([0.36, 0.36, 0.28]) -# .fit(X) -gmm_stats = [ubm.transform(x[np.newaxis]) for x in X] -jfa_machine = bob.learn.em.JFAMachine(ubm, r_U, r_V, em_iterations=50) +jfa_machine = bob.learn.em.JFAMachine(r_U, r_V, ubm=ubm, em_iterations=50) # Initializing with old bob initialization jfa_machine.U = np.array( @@ -89,13 +56,7 @@ jfa_machine.Y = np.array( jfa_machine.D = np.array( [0.732467, 0.281321, 0.543212, -0.512974, 1.04108, 0.835224] ) -jfa_machine = jfa_machine.fit(gmm_stats, y) - - -# .fit(gmm_stats, y) - -# gmm_stats = [ubm.transform(x) for x in [setosa, versicolor, virginica]] -# jfa_machine = bob.learn.em.JFAMachine(ubm, r_U, r_V).fit(gmm_stats, [0, 1, 2]) +jfa_machine = jfa_machine.fit(X, y) # Variability direction U -- GitLab