From deb0d23c6b74b52cf054704cb071e57aca4ed2ae Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Thu, 31 Mar 2022 13:29:40 +0200
Subject: [PATCH] Created enrollment and scoring functions that accept arrays
 as input

---
 bob/learn/em/factor_analysis.py       |  89 ++++-
 bob/learn/em/test/test_jfa.py         |  17 +
 bob/learn/em/test/test_jfa_trainer.py | 160 +++++++-
 doc/guide.rst                         | 552 +++++++-------------------
 doc/index.rst                         |   1 -
 doc/plot/plot_MAP.py                  |  58 +++
 doc/plot/plot_ML.py                   |  38 +-
 7 files changed, 478 insertions(+), 437 deletions(-)
 create mode 100644 doc/plot/plot_MAP.py

diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index 318dbba..275a8af 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -1105,6 +1106,27 @@ class FactorAnalysisBase(BaseEstimator):
 
         return fn_x.flatten()
 
+    def score_with_array(self, model, data):
+        """
+        Computes the score using a numpy array of features as input
+
+        Parameters
+        ----------
+        model : object
+            The enrolled model (the latent variables representing the client)
+
+        data : numpy.ndarray
+            2D array of features (one frame per row) to be scored
+
+        Returns
+        -------
+        score : float
+            The linear score
+
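+        Examples
+        --------
+        A minimal sketch (assuming ``machine`` is an already-fitted
+        ``ISVMachine`` or ``JFAMachine`` whose UBM expects 3-dimensional
+        features, and ``model`` was obtained with ``enroll_with_array``)::
+
+            import numpy as np
+            probe = np.random.normal(size=(20, 3))  # e.g. 20 frames of 3-dim features
+            score = machine.score_with_array(model, probe)
+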
+        """
+
+        return self.score(model, self.ubm.acc_statistics(data))
+
 
 class ISVMachine(FactorAnalysisBase):
     """
@@ -1210,11 +1232,11 @@ class ISVMachine(FactorAnalysisBase):
 
         y = y.tolist() if not isinstance(y, list) else y
 
-        # TODO: Point of parallelism
+        # TODO: Point of MAP-REDUCE
         n_acc, f_acc = self.initialize(X, y)
         for i in range(self.em_iterations):
             logger.info("U Training: Iteration %d", i)
-            # TODO: Point of parallelism
+            # TODO: Point of MAP-REDUCE
             acc_U_A1, acc_U_A2 = self.e_step(X, y, n_acc, f_acc)
             self.m_step(acc_U_A1, acc_U_A2)
 
@@ -1223,6 +1245,8 @@ class ISVMachine(FactorAnalysisBase):
     def enroll(self, X, iterations=1):
         """
         Enrolls a new client
+        In ISV, the enrolled model is defined as :math:`m + Dz`, with the latent
+        variable `z` representing the enrolled client.
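+        For a UBM with :math:`C` Gaussians of dimension :math:`D`, `z` (and thus
+        :math:`Dz`) is a supervector of size :math:`C \cdot D`.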
 
         Parameters
         ----------
@@ -1255,6 +1279,26 @@ class ISVMachine(FactorAnalysisBase):
 
         return latent_z
 
+    def enroll_with_array(self, X, iterations=1):
+        """
+        Enrolls a new client using a numpy array as input
+
+        Parameters
+        ----------
+        X : numpy.ndarray
+            2D array of features (one frame per row) to be enrolled
+
+        iterations : int
+            Number of iterations to perform
+
+        Returns
+        -------
+        z : numpy.ndarray
+            The latent variable `z` representing the enrolled model
+
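+        Examples
+        --------
+        A minimal sketch (assuming ``isv_machine`` is an already-fitted
+        ``ISVMachine`` whose UBM expects 3-dimensional features)::
+
+            import numpy as np
+            enroll_data = np.array([[1.2, 0.1, 1.4], [0.5, 0.2, 0.3]])
+            model = isv_machine.enroll_with_array(enroll_data)
+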
+        """
+        return self.enroll([self.ubm.acc_statistics(X)], iterations)
+
     def score(self, latent_z, data):
         """
         Computes the ISV score
@@ -1621,7 +1665,9 @@ class JFAMachine(FactorAnalysisBase):
 
     def enroll(self, X, iterations=1):
         """
-        Enrolls a new client
+        Enrolls a new client.
+        In JFA, the enrolled model is defined as :math:`m + Vy + Dz`, with the latent
+        variables `y` and `z` representing the enrolled client.
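+        For a UBM with :math:`C` Gaussians of dimension :math:`D`, `y` has size
+        :math:`r_V` and `z` has size :math:`C \cdot D`.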
 
         Parameters
         ----------
@@ -1633,8 +1679,8 @@ class JFAMachine(FactorAnalysisBase):
 
         Returns
         -------
-        self : object
-            z, y
+        (y, z) : tuple of numpy.ndarray
+            The latent variables `y` and `z` representing the enrolled model
 
         """
         # We have only one class for enrollment
@@ -1656,7 +1702,28 @@ class JFAMachine(FactorAnalysisBase):
                 X, y, latent_x, latent_y, latent_z, n_acc, f_acc
             )
 
-        return latent_y, latent_z
+        # The latent variables are kept in 2D arrays; return the single row of each
+        return latent_y[0], latent_z[0]
+
+    def enroll_with_array(self, X, iterations=1):
+        """
+        Enrolls a new client using a numpy array as input
+
+        Parameters
+        ----------
+        X : numpy.ndarray
+            2D array of features (one frame per row) to be enrolled
+
+        iterations : int
+            Number of iterations to perform
+
+        Returns
+        -------
+        (y, z) : tuple of numpy.ndarray
+            The latent variables `y` and `z` representing the enrolled model
+
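+        Examples
+        --------
+        A minimal sketch (assuming ``jfa_machine`` is an already-fitted
+        ``JFAMachine`` whose UBM expects 3-dimensional features)::
+
+            import numpy as np
+            enroll_data = np.array([[1.2, 0.1, 1.4], [0.5, 0.2, 0.3]])
+            y, z = jfa_machine.enroll_with_array(enroll_data)
+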
+        """
+        return self.enroll([self.ubm.acc_statistics(X)], iterations)
 
     def fit(self, X, y):
         """
@@ -1686,13 +1753,13 @@ class JFAMachine(FactorAnalysisBase):
 
         y = y.tolist() if not isinstance(y, list) else y
 
-        # TODO: Point of parallelism
+        # TODO: Point of MAP-REDUCE
         n_acc, f_acc = self.initialize(X, y)
 
         # Updating V
         for i in range(self.em_iterations):
             logger.info("V Training: Iteration %d", i)
-            # TODO: Point of parallelism
+            # TODO: Point of MAP-REDUCE
             acc_V_A1, acc_V_A2 = self.e_step_v(X, y, n_acc, f_acc)
             self.m_step_v(acc_V_A1, acc_V_A2)
         latent_y = self.finalize_v(X, y, n_acc, f_acc)
@@ -1700,7 +1767,7 @@ class JFAMachine(FactorAnalysisBase):
         # Updating U
         for i in range(self.em_iterations):
             logger.info("U Training: Iteration %d", i)
-            # TODO: Point of parallelism
+            # TODO: Point of MAP-REDUCE
             acc_U_A1, acc_U_A2 = self.e_step_u(X, y, latent_y)
             self.m_step_u(acc_U_A1, acc_U_A2)
 
@@ -1709,7 +1776,7 @@ class JFAMachine(FactorAnalysisBase):
         # Updating D
         for i in range(self.em_iterations):
             logger.info("D Training: Iteration %d", i)
-            # TODO: Point of parallelism
+            # TODO: Point of MAP-REDUCE
             acc_D_A1, acc_D_A2 = self.e_step_d(
                 X, y, latent_x, latent_y, n_acc, f_acc
             )
@@ -1719,7 +1786,7 @@ class JFAMachine(FactorAnalysisBase):
 
     def score(self, model, data):
         """
-        Computes the ISV score
+        Computes the JFA score
 
         Parameters
         ----------
diff --git a/bob/learn/em/test/test_jfa.py b/bob/learn/em/test/test_jfa.py
index c96f59e..de4f761 100644
--- a/bob/learn/em/test/test_jfa.py
+++ b/bob/learn/em/test/test_jfa.py
@@ -55,6 +55,14 @@ def test_JFAMachine():
     score = m.score(model, gs)
     assert abs(score_ref - score) < eps
 
+    # Scoring with numpy array
+    np.random.seed(0)
+    X = np.random.normal(loc=0.0, scale=1.0, size=(50, 3))
+    score_ref = 2.028009315286946
+    score = m.score_with_array(model, X)
+
+    assert abs(score_ref - score) < eps
+
 
 def test_ISVMachine():
 
@@ -97,3 +105,12 @@ def test_ISVMachine():
     score_ref = -3.280498193082100
 
     assert abs(score_ref - score) < eps
+
+    # Scoring with numpy array
+    np.random.seed(0)
+    X = np.random.normal(loc=0.0, scale=1.0, size=(50, 3))
+    score_ref = -1.2343813195374242
+
+    score = isv_machine.score_with_array(latent_z, X)
+
+    assert abs(score_ref - score) < eps
diff --git a/bob/learn/em/test/test_jfa_trainer.py b/bob/learn/em/test/test_jfa_trainer.py
index b4b75b7..5d41c7c 100644
--- a/bob/learn/em/test/test_jfa_trainer.py
+++ b/bob/learn/em/test/test_jfa_trainer.py
@@ -213,6 +213,84 @@ def test_JFATrainAndEnrol():
     assert np.allclose(latent_z, z_ref, eps)
 
 
+def test_JFATrainAndEnrolWithNumpy():
+    # Train and enroll a JFAMachine
+
+    # Calls the train function
+    ubm = GMMMachine(2, 3)
+    ubm.means = UBM_MEAN.reshape((2, 3))
+    ubm.variances = UBM_VAR.reshape((2, 3))
+    it = JFAMachine(ubm, 2, 2, em_iterations=10)
+
+    it.U = copy.deepcopy(M_u)
+    it.V = copy.deepcopy(M_v)
+    it.D = copy.deepcopy(M_d)
+    it.fit(TRAINING_STATS_X, TRAINING_STATS_y)
+
+    v_ref = np.array(
+        [
+            [0.245364911936476, 0.978133261775424],
+            [0.769646805052223, 0.940070736856596],
+            [0.310779202800089, 1.456332053893072],
+            [0.184760934399551, 2.265139705602147],
+            [0.701987784039800, 0.081632150899400],
+            [0.074344030229297, 1.090248340917255],
+        ],
+        "float64",
+    )
+    u_ref = np.array(
+        [
+            [0.049424652628448, 0.060480486336896],
+            [0.178104127464007, 1.884873813495153],
+            [1.204011484266777, 2.281351307871720],
+            [7.278512126426286, -0.390966087173334],
+            [-0.084424326581145, -0.081725474934414],
+            [4.042143689831097, -0.262576386580701],
+        ],
+        "float64",
+    )
+    d_ref = np.array(
+        [
+            9.648467e-18,
+            2.63720683155e-12,
+            2.11822157653706e-10,
+            9.1047243e-17,
+            1.41163442535567e-10,
+            3.30581e-19,
+        ],
+        "float64",
+    )
+
+    eps = 1e-10
+    assert np.allclose(it.V, v_ref, eps)
+    assert np.allclose(it.U, u_ref, eps)
+    assert np.allclose(it.D, d_ref, eps)
+
+    """
+    Calls the enroll function with arrays as input
+    """
+
+    np.random.seed(0)
+    X = np.random.normal(ubm.means[0], scale=0.5, size=(50, 3))
+    latent_y_ref = np.array([-0.13922039, 0.10686916])
+    latent_z_ref = np.array(
+        [
+            [
+                -1.37073043e-17,
+                1.15641870e-12,
+                -8.29922598e-10,
+                -4.17108194e-16,
+                -2.27107305e-09,
+                2.94293314e-18,
+            ]
+        ]
+    )
+
+    latent_y, latent_z = it.enroll_with_array(X)
+    assert np.allclose(latent_z, latent_z_ref, eps)
+    assert np.allclose(latent_y, latent_y_ref, eps)
+
+
 def test_ISVTrainAndEnrol():
     # Train and enroll an 'ISVMachine'
 
@@ -301,8 +379,86 @@ def test_ISVTrainAndEnrol():
 
     gse = [gse1, gse2]
 
-    latent_z = it.enroll(gse, 5)
-    assert np.allclose(latent_z, z_ref, eps)
+
+def test_ISVTrainAndEnrolWithNumpy():
+    # Train and enroll an 'ISVMachine'
+
+    eps = 1e-10
+    d_ref = np.array(
+        [
+            0.39601136,
+            0.07348469,
+            0.47712682,
+            0.44738127,
+            0.43179856,
+            0.45086029,
+        ],
+        "float64",
+    )
+    u_ref = np.array(
+        [
+            [0.855125642430777, 0.563104284748032],
+            [-0.325497865404680, 1.923598985291687],
+            [0.511575659503837, 1.964288663083095],
+            [9.330165761678115, 1.073623827995043],
+            [0.511099245664012, 0.278551249248978],
+            [5.065578541930268, 0.509565618051587],
+        ],
+        "float64",
+    )
+    z_ref = np.array(
+        [
+            -0.079315777443826,
+            0.092702428248543,
+            -0.342488761656616,
+            -0.059922635809136,
+            0.133539981073604,
+            0.213118695516570,
+        ],
+        "float64",
+    )
+
+    """
+    Calls the train function
+    """
+    ubm = GMMMachine(n_gaussians=2)
+    ubm.means = UBM_MEAN.reshape((2, 3))
+    ubm.variances = UBM_VAR.reshape((2, 3))
+
+    it = ISVMachine(
+        ubm,
+        r_U=2,
+        relevance_factor=4.0,
+        em_iterations=10,
+    )
+
+    it.U = copy.deepcopy(M_u)
+    it = it.fit(TRAINING_STATS_X, TRAINING_STATS_y)
+
+    assert np.allclose(it.D, d_ref, eps)
+    assert np.allclose(it.U, u_ref, eps)
+
+    """
+    Calls the enroll function with arrays as input
+    """
+
+    np.random.seed(0)
+    X = np.random.normal(ubm.means[0], scale=0.5, size=(50, 3))
+    latent_z_ref = np.array(
+        [
+            [
+                0.01084525,
+                0.06039035,
+                -0.16920933,
+                -0.17321376,
+                -0.9648409,
+                0.44581105,
+            ]
+        ]
+    )
+
+    latent_z = it.enroll_with_array(X)
+    assert np.allclose(latent_z, latent_z_ref, eps)
 
 
 def test_JFATrainInitialize():
diff --git a/doc/guide.rst b/doc/guide.rst
index 82d0b95..51e308d 100644
--- a/doc/guide.rst
+++ b/doc/guide.rst
@@ -101,8 +101,8 @@ This statistical model is defined in the class
    :options: +NORMALIZE_WHITESPACE +SKIP
 
    >>> import bob.learn.em
-   >>> # Create a GMM with k=2 Gaussians with the dimensionality of 3
-   >>> gmm_machine = bob.learn.em.GMMMachine(2, 3)
+   >>> # Create a GMM with k=2 Gaussians
+   >>> gmm_machine = bob.learn.em.GMMMachine(n_gaussians=2)
 
 
 There are plenty of ways to estimate :math:`\Theta`; the next subsections
@@ -118,7 +118,7 @@ the parameters of a statistical model given observations by finding the
 :math:`\Theta` that maximizes :math:`P(x|\Theta)` for all :math:`x` in your
 dataset [9]_. This optimization is done by the **Expectation-Maximization**
 (EM) algorithm [8]_ and it is implemented by
-:py:class:`bob.learn.em.ML_GMMTrainer`.
+:py:class:`bob.learn.em.GMMMachine` by setting the keyword argument `trainer="ml"`.
 
 A very nice explanation of EM algorithm for the maximum likelihood estimation
 can be found in this
@@ -130,7 +130,7 @@ estimator.
 
 
 .. doctest::
-   :options: +NORMALIZE_WHITESPACE +SKIP
+   :options: +NORMALIZE_WHITESPACE
 
    >>> import bob.learn.em
    >>> import numpy
@@ -141,35 +141,22 @@ estimator.
    ...      [-7,7,-100],
    ...      [-5,5,-101]], dtype='float64')
    >>> # Create a kmeans model (machine) m with k=2 clusters
-   >>> # with a dimensionality equal to 3
-   >>> gmm_machine = bob.learn.em.GMMMachine(2, 3)
-   >>> # Using the MLE trainer to train the GMM:
-   >>> # True, True, True means update means/variances/weights at each
-   >>> # iteration
-   >>> gmm_trainer = bob.learn.em.ML_GMMTrainer(True, True, True)
-   >>> # Setting some means to start the training.
-   >>> # In practice, the output of kmeans is a good start for the MLE training
-   >>> gmm_machine.means = numpy.array(
-   ...     [[ -4.,   2.3,  -10.5],
-   ...      [  2.5, -4.5,   59. ]])
-   >>> max_iterations = 200
-   >>> convergence_threshold = 1e-5
+   >>> # and use the MLE trainer to train the GMM.
+   >>> # In this setup, k-means is used to initialize the means, variances, and weights of the Gaussians
+   >>> gmm_machine = bob.learn.em.GMMMachine(n_gaussians=2, trainer="ml")
    >>> # Training
-   >>> bob.learn.em.train(gmm_trainer, gmm_machine, data,
-   ...                    max_iterations=max_iterations,
-   ...                    convergence_threshold=convergence_threshold)
+   >>> gmm_machine = gmm_machine.fit(data)
    >>> print(gmm_machine.means)
-   [[ -6.   6.  -100.5]
-    [  3.5 -3.5   99. ]]
+    [[   3.5   -3.5   99. ]
+     [  -6.     6.  -100.5]]
 
 Bellow follow an intuition of the GMM trained the maximum likelihood estimator
 using the Iris flower
 `dataset <https://en.wikipedia.org/wiki/Iris_flower_data_set>`_.
 
-..
-   TODO uncomment when implemented
-   .. plot:: plot/plot_ML.py
-      :include-source: False
+
+.. plot:: plot/plot_ML.py
+   :include-source: False
 
 
 Maximum a posteriori Estimator (MAP)
@@ -181,7 +168,7 @@ estimate that equals the mode of the posterior distribution by incorporating in
 its loss function a prior distribution [10]_. Commonly this prior distribution
 (the values of :math:`\Theta`) is estimated with MLE. This optimization is done
 by the **Expectation-Maximization** (EM) algorithm [8]_ and it is implemented
-by :py:class:`bob.learn.em.MAP_GMMTrainer`.
+by :py:class:`bob.learn.em.GMMMachine` by setting the keyword argument `trainer="map"`.
 
 A compact way to write relevance MAP adaptation is by using GMM supervector
 notation (this will be useful in the next subsections). The GMM supervector
@@ -195,7 +182,7 @@ Follow bellow an snippet on how to train a GMM using the MAP estimator.
 
 
 .. doctest::
-   :options: +NORMALIZE_WHITESPACE +SKIP
+   :options: +NORMALIZE_WHITESPACE
 
    >>> import bob.learn.em
    >>> import numpy
@@ -206,33 +193,31 @@ Follow bellow an snippet on how to train a GMM using the MAP estimator.
    ...      [-7,7,-100],
    ...      [-5,5,-101]], dtype='float64')
    >>> # Creating a fake prior
-   >>> prior_gmm = bob.learn.em.GMMMachine(2, 3)
-   >>> # Set some random means for the example
+   >>> prior_gmm = bob.learn.em.GMMMachine(2)
+   >>> # Set some random means/variances and weights for the example
    >>> prior_gmm.means = numpy.array(
    ...     [[ -4.,   2.3,  -10.5],
    ...      [  2.5, -4.5,   59. ]])
-   >>> # Creating the model for the adapted GMM
-   >>> adapted_gmm = bob.learn.em.GMMMachine(2, 3)
-   >>> # Creating the MAP trainer
-   >>> gmm_trainer = bob.learn.em.MAP_GMMTrainer(prior_gmm, relevance_factor=4)
-   >>>
-   >>> max_iterations = 200
-   >>> convergence_threshold = 1e-5
+   >>> prior_gmm.variances = numpy.array(
+   ...     [[ -0.1,   0.5,  -0.5],
+   ...      [  0.5, -0.5,   0.2 ]])
+   >>> prior_gmm.weights = numpy.array([ 0.8,   0.5])
+   >>> # Creating the model for the adapted GMM, setting `prior_gmm` as its prior (UBM)
+   >>> # Note that `trainer="map"` selects the maximum a posteriori estimator
+   >>> adapted_gmm = bob.learn.em.GMMMachine(2, ubm=prior_gmm, trainer="map")
    >>> # Training
-   >>> bob.learn.em.train(gmm_trainer, adapted_gmm, data,
-   ...                    max_iterations=max_iterations,
-   ...                    convergence_threshold=convergence_threshold)
+   >>> adapted_gmm = adapted_gmm.fit(data)
    >>> print(adapted_gmm.means)
-    [[ -4.667   3.533 -40.5  ]
-     [  2.929  -4.071  76.143]]
+    [[ -4.      2.3   -10.5  ]
+     [  0.944  -1.833  36.889]]
 
 Bellow follow an intuition of the GMM trained with the MAP estimator using the
 Iris flower `dataset <https://en.wikipedia.org/wiki/Iris_flower_data_set>`_.
+It can be observed how the MAP means (the red triangles) are adapted towards the center
+of each class, starting from the means of the prior GMM (the blue crosses).
 
-..
-   TODO uncomment when implemented
-   .. plot:: plot/plot_MAP.py
-      :include-source: False
+.. plot:: plot/plot_MAP.py
+   :include-source: False
 
 
 Session Variability Modeling with Gaussian Mixture Models
@@ -273,7 +258,7 @@ prior GMM.
 
 
 .. doctest::
-   :options: +NORMALIZE_WHITESPACE +SKIP
+   :options: +NORMALIZE_WHITESPACE
 
     >>> import bob.learn.em
     >>> import numpy
@@ -285,21 +270,13 @@ prior GMM.
     ...      [-0.3, -0.1, 0],
     ...      [1.2, 1.4, 1],
     ...      [0.8, 1., 1]], dtype='float64')
-    >>> # Creating a fake prior with 2 Gaussians of dimension 3
-    >>> prior_gmm = bob.learn.em.GMMMachine(2, 3)
-    >>> prior_gmm.means = numpy.vstack((numpy.random.normal(0, 0.5, (1, 3)),
-    ...                                 numpy.random.normal(1, 0.5, (1, 3))))
-    >>> # All nice and round diagonal covariance
-    >>> prior_gmm.variances = numpy.ones((2, 3)) * 0.5
-    >>> prior_gmm.weights = numpy.array([0.3, 0.7])
+    >>> # Training a GMM with 2 Gaussians of dimension 3
+    >>> prior_gmm = bob.learn.em.GMMMachine(2).fit(data)    
     >>> # Creating the container
-    >>> gmm_stats_container = bob.learn.em.GMMStats(2, 3)
-    >>> for d in data:
-    ...    prior_gmm.acc_statistics(d, gmm_stats_container)
-    >>>
+    >>> gmm_stats = prior_gmm.acc_statistics(data)    
     >>> # Printing the responsibilities
-    >>> print(gmm_stats_container.n/gmm_stats_container.t)
-     [0.429  0.571]
+    >>> print(gmm_stats.n/gmm_stats.t)
+     [0.6  0.4]
 
 
 Inter-Session Variability
@@ -307,80 +284,81 @@ Inter-Session Variability
 .. _isv:
 
 Inter-Session Variability (ISV) modeling [3]_ [2]_ is a session variability
-modeling technique built on top of the Gaussian mixture modeling approach. It
-hypothesizes that within-class variations are embedded in a linear subspace in
-the GMM means subspace and these variations can be suppressed by an offset w.r.t
-each mean during the MAP adaptation.
+modeling technique built on top of the Gaussian mixture modeling approach.
+It hypothesizes that within-class variations are embedded in a linear subspace of
+the GMM means space, and that these variations can be suppressed by an offset w.r.t. each mean during the MAP adaptation.
 
-In this generative model each sample is assumed to have been generated by a GMM
-mean supervector with the following shape:
+
+In this generative model, each sample is assumed to have been generated by a GMM mean supervector with the following shape:
 :math:`\mu_{i, j} = m + Ux_{i, j} + D_z{i}`, where :math:`m` is our prior,
 :math:`Ux_{i, j}` is the session offset that we want to suppress and
 :math:`D_z{i}` is the class offset (with all session effects suppressed).
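+
+For instance, for a UBM with :math:`C` Gaussian components of dimension
+:math:`D` and a session subspace of rank :math:`r_U`, the shapes involved are:
+
+.. math::
+   m, Dz_{i} \in \mathbb{R}^{CD}, \qquad
+   U \in \mathbb{R}^{CD \times r_U}, \qquad
+   x_{i, j} \in \mathbb{R}^{r_U}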
 
-All possible sources of session variations is embedded in this matrix
-:math:`U`. Follow bellow an intuition of what is modeled with :math:`U` in the
+It is hypothesized that all possible sources of session variation are embedded in the matrix
+:math:`U`. Below is an intuition of what is modeled with :math:`U` in the
 Iris flower `dataset <https://en.wikipedia.org/wiki/Iris_flower_data_set>`_.
 The arrows :math:`U_{1}`, :math:`U_{2}` and :math:`U_{3}` are the directions of
-the within class variations, with respect to each Gaussian component, that will
+the within-class variations, with respect to each Gaussian component, that will
 be suppressed a posteriori.
 
-..
-   TODO uncomment when implemented
-   .. plot:: plot/plot_ISV.py
-      :include-source: False
+
+
+.. plot:: plot/plot_ISV.py
+   :include-source: False
 
 
 The ISV statistical model is stored in this container
-:py:class:`bob.learn.em.ISVBase` and the training is performed by
-:py:class:`bob.learn.em.ISVTrainer`. The snippet bellow shows how to train a
-Intersession variability modeling.
+:py:class:`bob.learn.em.ISVMachine`.
+The snippet below shows how to:
+
+  - Train an inter-session variability model.
+  - Enroll a subject with such a model.
+  - Compute a score with such a model.
 
 
 .. doctest::
-   :options: +NORMALIZE_WHITESPACE +SKIP
+   :options: +NORMALIZE_WHITESPACE
+
+   >>> import bob.learn.em
+   >>> import numpy as np
+
+   >>> np.random.seed(10)
+
+   >>> # Generating some fake data
+   >>> data_class1 = np.random.normal(0, 0.5, (10, 3))
+   >>> data_class2 = np.random.normal(-0.2, 0.2, (10, 3))
+   >>> X = np.vstack((data_class1, data_class2))
+   >>> y = np.hstack((np.zeros(10, dtype=int), np.ones(10, dtype=int)))
+   >>> # Training a UBM with 2 Gaussians
+   >>> ubm = bob.learn.em.GMMMachine(2).fit(X)
+
+   >>> # The input of the ISV training is the GMM statistics
+   >>> # Here we create one GMMStats per data point, which is NOT usual,
+   >>> # but it is done for demonstration purposes
+   >>> gmm_stats = [ubm.acc_statistics(x[np.newaxis]) for x in X]
+
+   >>> # Finally doing the ISV training with a U subspace of dimension 2
+   >>> isv_machine = bob.learn.em.ISVMachine(ubm, r_U=2).fit(gmm_stats, y)
+   >>> print(isv_machine.U)
+     [[-0.079 -0.011]
+     [ 0.078  0.039]
+     [ 0.129  0.018]
+     [ 0.175  0.254]
+     [ 0.019  0.027]
+     [-0.132 -0.191]]
+
+   >>> # Enrolling a subject
+   >>> enroll_data = np.array([[1.2, 0.1, 1.4], [0.5, 0.2, 0.3]])
+   >>> model = isv_machine.enroll_with_array(enroll_data)
+   >>> print(model)
+     [[ 0.54   0.246  0.505  1.617 -0.791  0.746]]
+   
+   >>> # Probing
+   >>> probe_data = np.array([[1.2, 0.1, 1.4], [0.5, 0.2, 0.3]])
+   >>> score = isv_machine.score_with_array(model, probe_data)
+   >>> print(score)
+     [2.754]
 
-    >>> import bob.learn.em
-    >>> import numpy
-    >>> numpy.random.seed(10)
-    >>>
-    >>> # Generating some fake data
-    >>> data_class1 = numpy.random.normal(0, 0.5, (10, 3))
-    >>> data_class2 = numpy.random.normal(-0.2, 0.2, (10, 3))
-    >>> data = [data_class1, data_class2]
-
-    >>> # Creating a fake prior with 2 gaussians of dimension 3
-    >>> prior_gmm = bob.learn.em.GMMMachine(2, 3)
-    >>> prior_gmm.means = numpy.vstack((numpy.random.normal(0, 0.5, (1, 3)),
-    ...                                 numpy.random.normal(1, 0.5, (1, 3))))
-    >>> # All nice and round diagonal covariance
-    >>> prior_gmm.variances = numpy.ones((2, 3)) * 0.5
-    >>> prior_gmm.weights = numpy.array([0.3, 0.7])
-    >>> # The input the the ISV Training is the statistics of the GMM
-    >>> gmm_stats_per_class = []
-    >>> for d in data:
-    ...   stats = []
-    ...   for i in d:
-    ...     gmm_stats_container = bob.learn.em.GMMStats(2, 3)
-    ...     prior_gmm.acc_statistics(i, gmm_stats_container)
-    ...     stats.append(gmm_stats_container)
-    ...   gmm_stats_per_class.append(stats)
-
-    >>> # Finally doing the ISV training
-    >>> subspace_dimension_of_u = 2
-    >>> relevance_factor = 4
-    >>> isvbase = bob.learn.em.ISVBase(prior_gmm, subspace_dimension_of_u)
-    >>> trainer = bob.learn.em.ISVTrainer(relevance_factor)
-    >>> bob.learn.em.train(trainer, isvbase, gmm_stats_per_class,
-    ...                    max_iterations=50)
-    >>> # Printing the session offset w.r.t each Gaussian component
-    >>> print(isvbase.u)
-      [[-0.01  -0.027]
-      [-0.002 -0.004]
-      [ 0.028  0.074]
-      [ 0.012  0.03 ]
-      [ 0.033  0.085]
-      [ 0.046  0.12 ]]
 
 
 Joint Factor Analysis
@@ -401,308 +379,59 @@ between class variations with respect to each Gaussian component that will be
 added a posteriori.
 
 
-..
-   TODO uncomment when implemented
-   .. plot:: plot/plot_JFA.py
-      :include-source: False
+.. plot:: plot/plot_JFA.py
+   :include-source: False
 
 The JFA statistical model is stored in this container
-:py:class:`bob.learn.em.JFABase` and the training is performed by
-:py:class:`bob.learn.em.JFATrainer`. The snippet bellow shows how to train a
-Intersession variability modeling.
-
-.. doctest::
-   :options: +NORMALIZE_WHITESPACE +SKIP
-
-    >>> import bob.learn.em
-    >>> import numpy
-    >>> numpy.random.seed(10)
-    >>>
-    >>> # Generating some fake data
-    >>> data_class1 = numpy.random.normal(0, 0.5, (10, 3))
-    >>> data_class2 = numpy.random.normal(-0.2, 0.2, (10, 3))
-    >>> data = [data_class1, data_class2]
-
-    >>> # Creating a fake prior with 2 Gaussians of dimension 3
-    >>> prior_gmm = bob.learn.em.GMMMachine(2, 3)
-    >>> prior_gmm.means = numpy.vstack((numpy.random.normal(0, 0.5, (1, 3)),
-    ...                                 numpy.random.normal(1, 0.5, (1, 3))))
-    >>> # All nice and round diagonal covariance
-    >>> prior_gmm.variances = numpy.ones((2, 3)) * 0.5
-    >>> prior_gmm.weights = numpy.array([0.3, 0.7])
-    >>>
-    >>> # The input the the JFA Training is the statistics of the GMM
-    >>> gmm_stats_per_class = []
-    >>> for d in data:
-    ...   stats = []
-    ...   for i in d:
-    ...     gmm_stats_container = bob.learn.em.GMMStats(2, 3)
-    ...     prior_gmm.acc_statistics(i, gmm_stats_container)
-    ...     stats.append(gmm_stats_container)
-    ...   gmm_stats_per_class.append(stats)
-    >>>
-    >>> # Finally doing the JFA training
-    >>> subspace_dimension_of_u = 2
-    >>> subspace_dimension_of_v = 2
-    >>> relevance_factor = 4
-    >>> jfabase = bob.learn.em.JFABase(prior_gmm, subspace_dimension_of_u,
-    ...                                subspace_dimension_of_v)
-    >>> trainer = bob.learn.em.JFATrainer()
-    >>> bob.learn.em.train_jfa(trainer, jfabase, gmm_stats_per_class,
-    ...                        max_iterations=50)
-
-    >>> # Printing the session offset w.r.t each Gaussian component
-    >>> print(jfabase.v)
-     [[ 0.003 -0.006]
-      [ 0.041 -0.084]
-      [-0.261  0.53 ]
-      [-0.252  0.51 ]
-      [-0.387  0.785]
-      [-0.36   0.73 ]]
-
-Total variability Modelling
-===========================
-.. _ivector:
-
-Total Variability (TV) modeling [4]_ is a front-end initially introduced for
-speaker recognition, which aims at describing samples by vectors of low
-dimensionality called ``i-vectors``. The model consists of a subspace :math:`T`
-and a residual diagonal covariance matrix :math:`\Sigma`, that are then used to
-extract i-vectors, and is built upon the GMM approach. In the supervector
-notation this modeling has the following shape: :math:`\mu = m + Tv`.
-
-Follow bellow an intuition of the data from the Iris flower
-`dataset <https://en.wikipedia.org/wiki/Iris_flower_data_set>`_, embedded in
-the iVector space.
-
-..
-   TODO uncomment when implemented
-   .. plot:: plot/plot_iVector.py
-      :include-source: False
-
-
-The iVector statistical model is stored in this container
-:py:class:`bob.learn.em.IVectorMachine` and the training is performed by
-:py:class:`bob.learn.em.IVectorTrainer`. The snippet bellow shows how to train
-a Total variability modeling.
-
-.. doctest::
-   :options: +NORMALIZE_WHITESPACE +SKIP
-
-    >>> import bob.learn.em
-    >>> import numpy
-    >>> numpy.random.seed(10)
-    >>>
-    >>> # Generating some fake data
-    >>> data_class1 = numpy.random.normal(0, 0.5, (10, 3))
-    >>> data_class2 = numpy.random.normal(-0.2, 0.2, (10, 3))
-    >>> data = [data_class1, data_class2]
-    >>>
-    >>> # Creating a fake prior with 2 gaussians of dimension 3
-    >>> prior_gmm = bob.learn.em.GMMMachine(2, 3)
-    >>> prior_gmm.means = numpy.vstack((numpy.random.normal(0, 0.5, (1, 3)),
-    ...                                 numpy.random.normal(1, 0.5, (1, 3))))
-    >>> # All nice and round diagonal covariance
-    >>> prior_gmm.variances = numpy.ones((2, 3)) * 0.5
-    >>> prior_gmm.weights = numpy.array([0.3, 0.7])
-    >>>
-    >>> # The input the the TV Training is the statistics of the GMM
-    >>> gmm_stats_per_class = []
-    >>> for d in data:
-    ...     for i in d:
-    ...       gmm_stats_container = bob.learn.em.GMMStats(2, 3)
-    ...       prior_gmm.acc_statistics(i, gmm_stats_container)
-    ...       gmm_stats_per_class.append(gmm_stats_container)
-    >>>
-    >>> # Finally doing the TV training
-    >>> subspace_dimension_of_t = 2
-    >>>
-    >>> ivector_trainer = bob.learn.em.IVectorTrainer(update_sigma=True)
-    >>> ivector_machine = bob.learn.em.IVectorMachine(
-    ...     prior_gmm, subspace_dimension_of_t, 10e-5)
-    >>> # train IVector model
-    >>> bob.learn.em.train(ivector_trainer, ivector_machine,
-    ...                    gmm_stats_per_class, 500)
-    >>>
-    >>> # Printing the session offset w.r.t each Gaussian component
-    >>> print(ivector_machine.t)
-     [[ 0.11  -0.203]
-      [-0.124  0.014]
-      [ 0.296  0.674]
-      [ 0.447  0.174]
-      [ 0.425  0.583]
-      [ 0.394  0.794]]
-
-Linear Scoring
-==============
-.. _linearscoring:
-
-In :ref:`MAP <map>` adaptation, :ref:`ISV <isv>` and :ref:`JFA <jfa>` a
-traditional way to do scoring is via the log-likelihood ratio between the
-adapted model and the prior as the following:
-
-.. math::
-   score = ln(P(x | \Theta)) -  ln(P(x | \Theta_{prior})),
+:py:class:`bob.learn.em.JFAMachine`. The snippet below shows how to:
 
+  - Train a JFA model.
+  - Enroll a subject with such a model.
+  - Compute a score with such a model.
 
-(with :math:`\Theta` varying for each approach).
-
-A simplification proposed by [Glembek2009]_, called linear scoring,
-approximate this ratio using a first order Taylor series as the following:
-
-.. math::
-   score = \frac{\mu - \mu_{prior}}{\sigma_{prior}} f * (\mu_{prior} + U_x),
-
-where :math:`\mu` is the the GMM mean supervector (of the prior and the adapted
-model), :math:`\sigma` is the variance, supervector, :math:`f` is the first
-order GMM statistics (:py:class:`bob.learn.em.GMMStats.sum_px`) and
-:math:`U_x`, is possible channel offset (:ref:`ISV <isv>`).
-
-This scoring technique is implemented in :py:func:`bob.learn.em.linear_scoring`.
-The snippet bellow shows how to compute scores using this approximation.
 
 .. doctest::
-   :options: +NORMALIZE_WHITESPACE +SKIP
+   :options: +NORMALIZE_WHITESPACE
 
    >>> import bob.learn.em
-   >>> import numpy
-   >>> # Defining a fake prior
-   >>> prior_gmm = bob.learn.em.GMMMachine(3, 2)
-   >>> prior_gmm.means = numpy.array([[1, 1], [2, 2.1], [3, 3]])
-   >>> # Defining a fake prior
-   >>> adapted_gmm = bob.learn.em.GMMMachine(3,2)
-   >>> adapted_gmm.means = numpy.array([[1.5, 1.5], [2.5, 2.5], [2, 2]])
-   >>> # Defining an input
-   >>> input = numpy.array([[1.5, 1.5], [1.6, 1.6]])
-   >>> #Accumulating statistics of the GMM
-   >>> stats = bob.learn.em.GMMStats(3, 2)
-   >>> prior_gmm.acc_statistics(input, stats)
-   >>> score = bob.learn.em.linear_scoring(
-   ...     [adapted_gmm], prior_gmm, [stats], [],
-   ...     frame_length_normalisation=True)
+   >>> import numpy as np
+
+   >>> np.random.seed(10)
+
+   >>> # Generating some fake data
+   >>> data_class1 = np.random.normal(0, 0.5, (10, 3))
+   >>> data_class2 = np.random.normal(-0.2, 0.2, (10, 3))
+   >>> X = np.vstack((data_class1, data_class2))
+   >>> y = np.hstack((np.zeros(10, dtype=int), np.ones(10, dtype=int)))
+   >>> # Training a UBM with 2 Gaussians
+   >>> ubm = bob.learn.em.GMMMachine(2).fit(X)
+
+   >>> # The input of the JFA training is the GMM statistics
+   >>> # Here we create one GMMStats per data point, which is NOT usual,
+   >>> # but it is done for demonstration purposes
+   >>> gmm_stats = [ubm.acc_statistics(x[np.newaxis]) for x in X]
+
+   >>> # Finally doing the JFA training with U and V subspaces of dimension 2
+   >>> jfa_machine = bob.learn.em.JFAMachine(ubm, r_U=2, r_V=2).fit(gmm_stats, y)
+   >>> print(jfa_machine.U)      
+     [[-0.069 -0.029]
+     [ 0.079  0.039]
+     [ 0.123  0.042]
+     [ 0.17   0.255]
+     [ 0.018  0.027]
+     [-0.128 -0.192]]
+
+   >>> enroll_data = np.array([[1.2, 0.1, 1.4], [0.5, 0.2, 0.3]])
+   >>> model = jfa_machine.enroll_with_array(enroll_data)
+   >>> print(model)
+     (array([0.634, 0.165]), array([ 0.,  0.,  0.,  0., -0.,  0.]))
+
+   >>> probe_data = np.array([[1.2, 0.1, 1.4], [0.5, 0.2, 0.3]])
+   >>> score = jfa_machine.score_with_array(model, probe_data)
    >>> print(score)
-    [[0.254]]
-
-
-Probabilistic Linear Discriminant Analysis (PLDA)
--------------------------------------------------
-
-Probabilistic Linear Discriminant Analysis [5]_ is a probabilistic model that
-incorporates components describing both between-class and within-class
-variations. Given a mean :math:`\mu`, between-class and within-class subspaces
-:math:`F` and :math:`G` and residual noise :math:`\epsilon` with zero mean and
-diagonal covariance matrix :math:`\Sigma`, the model assumes that a sample
-:math:`x_{i,j}` is generated by the following process:
-
-.. math::
-
-   x_{i,j} = \mu + F h_{i} + G w_{i,j} + \epsilon_{i,j}
-
-
-An Expectation-Maximization algorithm can be used to learn the parameters of
-this model :math:`\mu`, :math:`F` :math:`G` and :math:`\Sigma`. As these
-parameters can be shared between classes, there is a specific container class
-for this purpose, which is :py:class:`bob.learn.em.PLDABase`. The process is
-described in detail in [6]_.
-
-Let us consider a training set of two classes, each with 3 samples of
-dimensionality 3.
-
-.. doctest::
-   :options: +NORMALIZE_WHITESPACE +SKIP
-
-   >>> data1 = numpy.array(
-   ...     [[3,-3,100],
-   ...      [4,-4,50],
-   ...      [40,-40,150]], dtype=numpy.float64)
-   >>> data2 = numpy.array(
-   ...     [[3,6,-50],
-   ...      [4,8,-100],
-   ...      [40,79,-800]], dtype=numpy.float64)
-   >>> data = [data1,data2]
-
-Learning a PLDA model can be performed by instantiating the class
-:py:class:`bob.learn.em.PLDATrainer`, and calling the
-:py:meth:`bob.learn.em.train` method.
-
-.. doctest::
-   :options: +SKIP
-
-   >>> # This creates a PLDABase container for input feature of dimensionality
-   >>> # 3 and with subspaces F and G of rank 1 and 2, respectively.
-   >>> pldabase = bob.learn.em.PLDABase(3,1,2)
-
-   >>> trainer = bob.learn.em.PLDATrainer()
-   >>> bob.learn.em.train(trainer, pldabase, data, max_iterations=10)
-
-Once trained, this PLDA model can be used to compute the log-likelihood of a
-set of samples given some hypothesis. For this purpose, a
-:py:class:`bob.learn.em.PLDAMachine` should be instantiated. Then, the
-log-likelihood that a set of samples share the same latent identity variable
-:math:`h_{i}` (i.e. the samples are coming from the same identity/class) is
-obtained by calling the
-:py:meth:`bob.learn.em.PLDAMachine.compute_log_likelihood()` method.
-
-.. doctest::
-   :options: +SKIP
-
-   >>> plda = bob.learn.em.PLDAMachine(pldabase)
-   >>> samples = numpy.array(
-   ...     [[3.5,-3.4,102],
-   ...      [4.5,-4.3,56]], dtype=numpy.float64)
-   >>> loglike = plda.compute_log_likelihood(samples)
-
-If separate models for different classes need to be enrolled, each of them with
-a set of enrollment samples, then, several instances of
-:py:class:`bob.learn.em.PLDAMachine` need to be created and enrolled using
-the :py:meth:`bob.learn.em.PLDATrainer.enroll()` method as follows.
-
-.. doctest::
-   :options: +SKIP
-
-   >>> plda1 = bob.learn.em.PLDAMachine(pldabase)
-   >>> samples1 = numpy.array(
-   ...     [[3.5,-3.4,102],
-   ...      [4.5,-4.3,56]], dtype=numpy.float64)
-   >>> trainer.enroll(plda1, samples1)
-   >>> plda2 = bob.learn.em.PLDAMachine(pldabase)
-   >>> samples2 = numpy.array(
-   ...     [[3.5,7,-49],
-   ...      [4.5,8.9,-99]], dtype=numpy.float64)
-   >>> trainer.enroll(plda2, samples2)
-
-Afterwards, the joint log-likelihood of the enrollment samples and of one or
-several test samples can be computed as previously described, and this
-separately for each model.
-
-.. doctest::
-   :options: +SKIP
-
-   >>> sample = numpy.array([3.2,-3.3,58], dtype=numpy.float64)
-   >>> l1 = plda1.compute_log_likelihood(sample)
-   >>> l2 = plda2.compute_log_likelihood(sample)
-
-In a verification scenario, there are two possible hypotheses:
-
-#. :math:`x_{test}` and :math:`x_{enroll}` share the same class.
-#. :math:`x_{test}` and :math:`x_{enroll}` are from different classes.
-
-Using the methods :py:meth:`bob.learn.em.PLDAMachine.log_likelihood_ratio` or
-its alias ``__call__`` function, the corresponding log-likelihood ratio will be
-computed, which is defined in more formal way by:
-:math:`s = \ln(P(x_{test},x_{enroll})) - \ln(P(x_{test})P(x_{enroll}))`
-
-.. doctest::
-   :options: +SKIP
-
-   >>> s1 = plda1(sample)
-   >>> s2 = plda2(sample)
-
-.. testcleanup:: *
+     [0.471]
 
-  import shutil
-  os.chdir(current_directory)
-  shutil.rmtree(temp_dir)
 
 
 
@@ -711,9 +440,6 @@ computed, which is defined in more formal way by:
 .. [1] http://dx.doi.org/10.1109/TASL.2006.881693
 .. [2] http://publications.idiap.ch/index.php/publications/show/2606
 .. [3] http://dx.doi.org/10.1016/j.csl.2007.05.003
-.. [4] http://dx.doi.org/10.1109/TASL.2010.2064307
-.. [5] http://dx.doi.org/10.1109/ICCV.2007.4409052
-.. [6] http://doi.ieeecomputersociety.org/10.1109/TPAMI.2013.38
 .. [7] http://en.wikipedia.org/wiki/K-means_clustering
 .. [8] http://en.wikipedia.org/wiki/Expectation-maximization_algorithm
 .. [9] http://en.wikipedia.org/wiki/Maximum_likelihood
diff --git a/doc/index.rst b/doc/index.rst
index 14a5f1d..2e35eef 100644
--- a/doc/index.rst
+++ b/doc/index.rst
@@ -51,7 +51,6 @@ References
 ..
    .. [Roweis1998] Roweis, Sam. "EM algorithms for PCA and SPCA." Advances in neural information processing systems (1998): 626-632.
 
-.. [Glembek2009] Glembek, Ondrej, et al. "Comparison of scoring methods used in speaker recognition with joint factor analysis." Acoustics, Speech and Signal Processing, 2009. ICASSP 2009. IEEE International Conference on. IEEE, 2009.
 
 
 Indices and tables
diff --git a/doc/plot/plot_MAP.py b/doc/plot/plot_MAP.py
new file mode 100644
index 0000000..f668c8c
--- /dev/null
+++ b/doc/plot/plot_MAP.py
@@ -0,0 +1,58 @@
+import matplotlib.pyplot as plt
+from sklearn.datasets import load_iris
+
+import bob.learn.em
+import numpy as np
+
+np.random.seed(10)
+
+iris_data = load_iris()
+data = np.column_stack((iris_data.data[:, 0], iris_data.data[:, 3]))
+setosa = data[iris_data.target == 0]
+versicolor = data[iris_data.target == 1]
+virginica = data[iris_data.target == 2]
+
+
+# A GMM with 3 Gaussians, used as the prior (UBM) for the MAP adaptation
+mle_machine = bob.learn.em.GMMMachine(3)
+# Creating some fake means and variances for the example
+mle_machine.means = np.array([[5, 3], [4, 2], [7, 3.0]])
+mle_machine.variances = np.array([[0.1, 0.5], [0.2, 0.2], [0.7, 0.5]])
+
+
+# MAP-adapting the prior GMM to the Iris data
+map_machine = bob.learn.em.GMMMachine(
+    3, trainer="map", ubm=mle_machine, map_relevance_factor=4
+).fit(data)
+
+
+figure, ax = plt.subplots()
+# plt.scatter(data[:, 0], data[:, 1], c="olivedrab", label="new data")
+plt.scatter(setosa[:, 0], setosa[:, 1], c="darkcyan", label="setosa")
+plt.scatter(
+    versicolor[:, 0], versicolor[:, 1], c="goldenrod", label="versicolor"
+)
+plt.scatter(virginica[:, 0], virginica[:, 1], c="dimgrey", label="virginica")
+plt.scatter(
+    mle_machine.means[:, 0],
+    mle_machine.means[:, 1],
+    c="blue",
+    marker="x",
+    label="prior centroids - mle",
+    s=60,
+)
+plt.scatter(
+    map_machine.means[:, 0],
+    map_machine.means[:, 1],
+    c="red",
+    marker="^",
+    label="adapted centroids - map",
+    s=60,
+)
+plt.legend()
+plt.xticks([], [])
+plt.yticks([], [])
+ax.set_xlabel("Sepal length")
+ax.set_ylabel("Petal width")
+plt.tight_layout()
+plt.show()
diff --git a/doc/plot/plot_ML.py b/doc/plot/plot_ML.py
index 32f6643..410abf0 100644
--- a/doc/plot/plot_ML.py
+++ b/doc/plot/plot_ML.py
@@ -1,7 +1,7 @@
 import logging
 
 import matplotlib.pyplot as plt
-import numpy
+import numpy as np
 
 from matplotlib.lines import Line2D
 from matplotlib.patches import Ellipse
@@ -13,7 +13,7 @@ logger = logging.getLogger("bob.learn.em")
 logger.setLevel("DEBUG")
 
 iris_data = load_iris()
-data = iris_data.data
+data = np.column_stack((iris_data.data[:, 0], iris_data.data[:, 3]))
 setosa = data[iris_data.target == 0]
 versicolor = data[iris_data.target == 1]
 virginica = data[iris_data.target == 2]
@@ -28,7 +28,6 @@ machine = GMMMachine(
 )
 
-# Initialize the means with known values (optional, skips kmeans)
-machine.means = numpy.array([[5, 3], [4, 2], [7, 3]], dtype=float)
 machine = machine.fit(data)
 
 
@@ -48,14 +47,33 @@ ax.scatter(
     s=60,
 )
 
+
+def draw_ellipse(position, covariance, ax=None, **kwargs):
+    """
+    Draw an ellipse with a given position and covariance
+    """
+    ax = ax or plt.gca()
+
+    # Convert covariance to principal axes
+    if covariance.shape == (2, 2):
+        U, s, Vt = np.linalg.svd(covariance)
+        angle = np.degrees(np.arctan2(U[1, 0], U[0, 0]))
+        width, height = 2 * np.sqrt(s)
+    else:
+        angle = 0
+        width, height = 2 * np.sqrt(covariance)
+
+    # Draw the Ellipse
+    for nsig in range(1, 4):
+        ax.add_patch(
+            Ellipse(position, nsig * width, nsig * height, angle=angle, **kwargs)
+        )
+
+
 # Draw ellipses for covariance
-for mean, variance in zip(machine.means, machine.variances):
-    eigvals, eigvecs = numpy.linalg.eig(numpy.diag(variance))
-    axis = numpy.sqrt(eigvals) * numpy.sqrt(5.991)
-    angle = 180.0 * numpy.arctan(eigvecs[1][0] / eigvecs[1][1]) / numpy.pi
-    ax.add_patch(
-        Ellipse(mean, *axis, angle=angle, linewidth=1, fill=False, zorder=2)
-    )
+w_factor = 0.2 / np.max(machine.weights)
+for w, pos, covar in zip(machine.weights, machine.means, machine.variances):
+    draw_ellipse(pos, covar, alpha=w * w_factor)
 
 # Plot details (legend, axis labels)
 plt.legend(
-- 
GitLab