From eb4077e57cd174e359449f17ac406ed82509c3e1 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Wed, 20 Apr 2022 12:11:31 +0200
Subject: [PATCH] Refactor factor_analysis to accept n_classes

---
 bob/learn/em/factor_analysis.py           | 368 +++++++++++++++++-----
 bob/learn/em/gmm.py                       |  12 +-
 bob/learn/em/kmeans.py                    |  28 +-
 bob/learn/em/test/test_factor_analysis.py |  14 +-
 bob/learn/em/utils.py                     |  33 ++
 5 files changed, 343 insertions(+), 112 deletions(-)
 create mode 100644 bob/learn/em/utils.py

diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index cb5680a..09d158c 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -2,21 +2,37 @@
 # @author: Tiago de Freitas Pereira
 
 
+import functools
 import logging
+import operator
 
 import dask
 import numpy as np
 
+from dask.delayed import Delayed
 from sklearn.base import BaseEstimator
 from sklearn.utils.multiclass import unique_labels
 
 from .gmm import GMMMachine
-from .kmeans import check_and_persist_dask_input
 from .linear_scoring import linear_scoring
+from .utils import array_to_delayed_list, check_and_persist_dask_input
 
 logger = logging.getLogger(__name__)
 
 
+def is_input_delayed(X):
+    """
+    Check if the input is a dask delayed object or a list/tuple of them.
+
+    Only the first element of a (possibly nested) list/tuple is inspected;
+    an empty list/tuple is reported as not delayed.
+    """
+    if isinstance(X, (list, tuple)):
+        return bool(X) and is_input_delayed(X[0])
+
+    return isinstance(X, Delayed)
+
+
 def mult_along_axis(A, B, axis):
     """
     Magic function to multiply two arrays along a given axis.
@@ -106,7 +122,7 @@ class FactorAnalysisBase(BaseEstimator):
 
         self.relevance_factor = relevance_factor
 
-        if ubm is not None:
+        if ubm is not None and ubm._means is not None:
             self.create_UVD()
 
     @property
@@ -208,20 +224,38 @@ class FactorAnalysisBase(BaseEstimator):
         if self.ubm is None:
             logger.info("FA: Creating a new GMMMachine and training it.")
             self.ubm = GMMMachine(**self.ubm_kwargs)
-            self.ubm.fit(X)  # GMMMachine.fit takes non-labeled data
+            self.ubm.fit(X)
+
+        if self.ubm._means is None:
+            self.ubm.fit(X)
 
         # Initializing the state matrix
         if not hasattr(self, "_U") or not hasattr(self, "_D"):
             self.create_UVD()
 
-    def initialize_using_stats(self, ubm_projected_X, y):
+    def initialize_using_stats(self, ubm_projected_X, y, n_classes):
         # Accumulating 0th and 1st order statistics
         # https://gitlab.idiap.ch/bob/bob.learn.em/-/blob/da92d0e5799d018f311f1bf5cdd5a80e19e142ca/bob/learn/em/cpp/ISVTrainer.cpp#L68
-        # 0th order stats
-        n_acc = self._sum_n_statistics(ubm_projected_X, y)
 
-        # 1st order stats
-        f_acc = self._sum_f_statistics(ubm_projected_X, y)
+        if is_input_delayed(ubm_projected_X):
+            n_acc = [
+                dask.delayed(self._sum_n_statistics)(xx, yy, n_classes)
+                for xx, yy in zip(ubm_projected_X, y)
+            ]
+            n_acc = dask.compute(*n_acc)
+            n_acc = functools.reduce(operator.iadd, n_acc)
+
+            f_acc = [
+                dask.delayed(self._sum_f_statistics)(xx, yy, n_classes)
+                for xx, yy in zip(ubm_projected_X, y)
+            ]
+            f_acc = dask.compute(*f_acc)
+            f_acc = functools.reduce(operator.iadd, f_acc)
+        else:
+            # 0th order stats
+            n_acc = self._sum_n_statistics(ubm_projected_X, y, n_classes)
+            # 1st order stats
+            f_acc = self._sum_f_statistics(ubm_projected_X, y, n_classes)
 
         return n_acc, f_acc
 
@@ -258,7 +292,7 @@ class FactorAnalysisBase(BaseEstimator):
         else:
             self._V = 0
 
-    def _sum_n_statistics(self, X, y):
+    def _sum_n_statistics(self, X, y, n_classes):
         """
         Accumulates the 0th statistics for each client
 
@@ -270,6 +304,9 @@ class FactorAnalysisBase(BaseEstimator):
             y: list of ints
                 List of corresponding labels
 
+            n_classes: int
+                Number of classes
+
         Returns
         -------
             n_acc: array
@@ -277,9 +314,7 @@ class FactorAnalysisBase(BaseEstimator):
 
         """
         # 0th order stats
-        n_acc = np.zeros(
-            (self.estimate_number_of_classes(y), self.ubm.n_gaussians)
-        )
+        n_acc = np.zeros((n_classes, self.ubm.n_gaussians))
 
         # Iterate for each client
         for x_i, y_i in zip(X, y):
@@ -288,7 +323,7 @@ class FactorAnalysisBase(BaseEstimator):
 
         return n_acc
 
-    def _sum_f_statistics(self, X, y):
+    def _sum_f_statistics(self, X, y, n_classes):
         """
         Accumulates the 1st order statistics for each client
 
@@ -299,6 +334,9 @@ class FactorAnalysisBase(BaseEstimator):
             y: list of ints
                List of corresponding labels
 
+            n_classes: int
+                Number of classes
+
         Returns
         -------
             f_acc: array
@@ -309,7 +347,7 @@ class FactorAnalysisBase(BaseEstimator):
         # 1st order stats
         f_acc = np.zeros(
             (
-                self.estimate_number_of_classes(y),
+                n_classes,
                 self.ubm.n_gaussians,
                 self.feature_dimension,
             )
@@ -410,7 +448,9 @@ class FactorAnalysisBase(BaseEstimator):
 
         return fn_x_ih
 
-    def update_x(self, X, y, UProd, latent_x, latent_y=None, latent_z=None):
+    def update_x(
+        self, X, y, n_classes, UProd, latent_x, latent_y=None, latent_z=None
+    ):
         """
         Computes a new math:`E[x]` See equation (29) in [McCool2013]_
 
@@ -424,6 +464,9 @@ class FactorAnalysisBase(BaseEstimator):
             y: list of ints
                 List of corresponding labels
 
+            n_classes: int
+                Number of classes
+
             UProd: array
                 Matrix containing U_c.T*inv(Sigma_c) @ U_c.T
 
@@ -445,7 +488,7 @@ class FactorAnalysisBase(BaseEstimator):
         # U.T @ inv(Sigma) - See Eq(37)
         UTinvSigma = self._U.T / self.variance_supervector
 
-        session_offsets = np.zeros(self.estimate_number_of_classes(y))
+        session_offsets = np.zeros(n_classes)
         # For each sample
         for x_i, y_i in zip(X, y):
             id_plus_prod_ih = self._compute_id_plus_u_prod_ih(x_i, UProd)
@@ -789,7 +832,7 @@ class FactorAnalysisBase(BaseEstimator):
 
         return acc_D_A1, acc_D_A2
 
-    def initialize_XYZ(self, y):
+    def initialize_XYZ(self, y, n_classes):
         """
         Initialize E[x], E[y], E[z] state variables
 
@@ -819,14 +862,12 @@ class FactorAnalysisBase(BaseEstimator):
             )
 
         latent_y = (
-            np.zeros((self.estimate_number_of_classes(y), self.r_V))
+            np.zeros((n_classes, self.r_V))
             if self.r_V and self.r_V > 0
             else None
         )
 
-        latent_z = np.zeros(
-            (self.estimate_number_of_classes(y), self.supervector_dimension)
-        )
+        latent_z = np.zeros((n_classes, self.supervector_dimension))
 
         return latent_x, latent_y, latent_z
 
@@ -834,7 +875,9 @@ class FactorAnalysisBase(BaseEstimator):
      Estimating V and y
     """
 
-    def update_y(self, X, y, VProd, latent_x, latent_y, latent_z, n_acc, f_acc):
+    def update_y(
+        self, X, y, n_classes, VProd, latent_x, latent_y, latent_z, n_acc, f_acc
+    ):
         """
         Computes a new math:`E[y]` See equation (30) in [McCool2013]_
 
@@ -847,6 +890,9 @@ class FactorAnalysisBase(BaseEstimator):
         y: list of ints
             List of corresponding labels
 
+        n_classes: int
+            Number of classes
+
         VProd: array
             Matrix representing V_c.T*inv(Sigma_c) @ V_c.T
 
@@ -870,7 +916,7 @@ class FactorAnalysisBase(BaseEstimator):
         VTinvSigma = self._V.T / self.variance_supervector
 
         # Loops over the labels
-        for label in range(self.estimate_number_of_classes(y)):
+        for label in range(n_classes):
             id_plus_v_prod_i = self._compute_id_plus_vprod_i(
                 n_acc[label], VProd
             )
@@ -1153,13 +1199,32 @@ class FactorAnalysisBase(BaseEstimator):
     def fit(self, X, y):
+        """Fit the model on ``X`` (array-like or dask array) with labels ``y``.
+
+        Dask input is handled chunk by chunk. Returns ``self``.
+        """
         input_is_dask, X = check_and_persist_dask_input(X)
         self.initialize(X)
+        # keep y at least 1-D: squeezing a single label would give a 0-d array
+        y = np.atleast_1d(np.squeeze(np.asarray(y)))
         if input_is_dask:
-            stats = [dask.delayed(self.ubm.transform)(xx) for xx in X]
-            stats = dask.compute(*stats)
+            # split y at the first-axis chunk boundaries of X
+            y = np.split(y, np.cumsum(X.chunks[0])[:-1])
+            X = array_to_delayed_list(X, input_is_dask)
+            if len(X) != len(y):
+                # extra blocks mean X is chunked on another axis as well
+                raise ValueError(
+                    "X must be chunked along its first axis only."
+                )
+            # persist so each chunk's statistics are computed only once
+            stats = [
+                dask.delayed(self.ubm.transform_one_by_one)(xx).persist()
+                for xx in X
+            ]
         else:
-            stats = [self.ubm.transform(xx) for xx in X]
+            stats = self.ubm.transform_one_by_one(X)
         del X  # we don't need to persist X anymore
-        return self.fit_using_stats(stats, y)
+
+        self.fit_using_stats(stats, y)
+        return self
 
 
 class ISVMachine(FactorAnalysisBase):
@@ -1213,18 +1278,30 @@ class ISVMachine(FactorAnalysisBase):
             **kwargs,
         )
 
-    def e_step(self, X, y, n_acc, f_acc):
+    def e_step(self, X, y, n_classes, n_acc, f_acc):
         """
         E-step of the EM algorithm
         """
         # self.initialize_XYZ(y)
         UProd = self._compute_uprod()
-        latent_x, _, latent_z = self.initialize_XYZ(y)
+        latent_x, _, latent_z = self.initialize_XYZ(y, n_classes=n_classes)
         latent_y = None
 
-        latent_x = self.update_x(X, y, UProd, latent_x)
+        latent_x = self.update_x(
+            X=X,
+            y=y,
+            n_classes=n_classes,
+            UProd=UProd,
+            latent_x=latent_x,
+        )
         latent_z = self.update_z(
-            X, y, latent_x, latent_y, latent_z, n_acc, f_acc
+            X=X,
+            y=y,
+            latent_x=latent_x,
+            latent_y=latent_y,
+            latent_z=latent_z,
+            n_acc=n_acc,
+            f_acc=f_acc,
         )
         acc_U_A1, acc_U_A2 = self.compute_accumulators_U(
             X, y, UProd, latent_x, latent_y, latent_z
@@ -1232,7 +1309,7 @@ class ISVMachine(FactorAnalysisBase):
 
         return acc_U_A1, acc_U_A2
 
-    def m_step(self, acc_U_A1, acc_U_A2):
+    def m_step(self, acc_U_A1_acc_U_A2_list):
         """
         ISV M-step.
         This updates `U` matrix
@@ -1247,6 +1324,11 @@ class ISVMachine(FactorAnalysisBase):
                 Accumulated statistics for U_A2(n_gaussians* feature_dimension, r_U)
 
         """
+        acc_U_A1 = [acc[0] for acc in acc_U_A1_acc_U_A2_list]
+        acc_U_A1 = functools.reduce(operator.iadd, acc_U_A1)
+
+        acc_U_A2 = [acc[1] for acc in acc_U_A1_acc_U_A2_list]
+        acc_U_A2 = functools.reduce(operator.iadd, acc_U_A2)
 
         self.update_U(acc_U_A1, acc_U_A2)
 
@@ -1269,14 +1351,36 @@ class ISVMachine(FactorAnalysisBase):
         """
 
         y = np.asarray(y)
+        n_classes = self.estimate_number_of_classes(y)
+        n_acc, f_acc = self.initialize_using_stats(X, y, n_classes)
 
-        # TODO: Point of MAP-REDUCE
-        n_acc, f_acc = self.initialize_using_stats(X, y)
         for i in range(self.em_iterations):
             logger.info("U Training: Iteration %d", i + 1)
             # TODO: Point of MAP-REDUCE
-            acc_U_A1, acc_U_A2 = self.e_step(X, y, n_acc, f_acc)
-            self.m_step(acc_U_A1, acc_U_A2)
+            if is_input_delayed(X):
+                acc_U_A1_acc_U_A2_list = [
+                    dask.delayed(self.e_step)(
+                        X=xx,
+                        y=yy,
+                        n_classes=n_classes,
+                        n_acc=n_acc,
+                        f_acc=f_acc,
+                    )
+                    for xx, yy in zip(X, y)
+                ]
+                delayed_em_step = dask.delayed(self.m_step)(
+                    acc_U_A1_acc_U_A2_list
+                )
+                dask.compute(delayed_em_step)
+            else:
+                acc_U_A1, acc_U_A2 = self.e_step(
+                    X=X,
+                    y=y,
+                    n_classes=n_classes,
+                    n_acc=n_acc,
+                    f_acc=f_acc,
+                )
+                self.m_step([(acc_U_A1, acc_U_A2)])
 
         return self
 
@@ -1284,7 +1388,7 @@ class ISVMachine(FactorAnalysisBase):
         ubm_projected_X = self.ubm.transform(X)
         return self.estimate_ux(ubm_projected_X)
 
-    def enroll(self, X, iterations=1):
+    def enroll_using_stats(self, X, iterations=1):
         """
         Enrolls a new client
         In ISV, the enrolment is defined as: :math:`m + Dz` with the latent variables `z`
@@ -1295,6 +1399,7 @@ class ISVMachine(FactorAnalysisBase):
         X : list of :py:class:`bob.learn.em.GMMStats`
             List of statistics to be enrolled
 
+
         iterations : int
             Number of iterations to perform
 
@@ -1306,22 +1411,36 @@ class ISVMachine(FactorAnalysisBase):
         """
         # We have only one class for enrollment
         y = list(np.zeros(len(X), dtype=np.int32))
-        n_acc = self._sum_n_statistics(X, y=y)
-        f_acc = self._sum_f_statistics(X, y=y)
+        n_acc = self._sum_n_statistics(X, y=y, n_classes=1)
+        f_acc = self._sum_f_statistics(X, y=y, n_classes=1)
 
         UProd = self._compute_uprod()
-        latent_x, _, latent_z = self.initialize_XYZ(y)
+        latent_x, _, latent_z = self.initialize_XYZ(y, n_classes=1)
         latent_y = None
         for i in range(iterations):
             logger.info("Enrollment: Iteration %d", i + 1)
-            latent_x = self.update_x(X, y, UProd, latent_x, latent_y, latent_z)
+            latent_x = self.update_x(
+                X=X,
+                y=y,
+                n_classes=1,
+                UProd=UProd,
+                latent_x=latent_x,
+                latent_y=latent_y,
+                latent_z=latent_z,
+            )
             latent_z = self.update_z(
-                X, y, latent_x, latent_y, latent_z, n_acc, f_acc
+                X=X,
+                y=y,
+                latent_x=latent_x,
+                latent_y=latent_y,
+                latent_z=latent_z,
+                n_acc=n_acc,
+                f_acc=f_acc,
             )
 
         return latent_z
 
-    def enroll_with_array(self, X, iterations=1):
+    def enroll(self, X, iterations=1):
         """
         Enrolls a new client using a numpy array as input
 
@@ -1339,7 +1458,7 @@ class ISVMachine(FactorAnalysisBase):
             z
 
         """
-        return self.enroll([self.ubm.transform(X)], iterations)
+        return self.enroll_using_stats([self.ubm.transform(X)], iterations)
 
     def score_using_stats(self, latent_z, data):
         """
@@ -1431,7 +1550,7 @@ class JFAMachine(FactorAnalysisBase):
             **kwargs,
         )
 
-    def e_step_v(self, X, y, n_acc, f_acc):
+    def e_step_v(self, X, y, n_classes, n_acc, f_acc):
         """
         ISV E-step for the V matrix.
 
@@ -1444,6 +1563,9 @@ class JFAMachine(FactorAnalysisBase):
             y: list of int
                 List of labels
 
+            n_classes: int
+                Number of classes
+
             n_acc: array
                 Accumulated 0th-order statistics
 
@@ -1464,16 +1586,33 @@ class JFAMachine(FactorAnalysisBase):
 
         VProd = self._compute_vprod()
 
-        latent_x, latent_y, latent_z = self.initialize_XYZ(y)
+        latent_x, latent_y, latent_z = self.initialize_XYZ(
+            y, n_classes=n_classes
+        )
 
         # UPDATE Y, X AND FINALLY Z
 
         latent_y = self.update_y(
-            X, y, VProd, latent_x, latent_y, latent_z, n_acc, f_acc
+            X=X,
+            y=y,
+            n_classes=n_classes,
+            VProd=VProd,
+            latent_x=latent_x,
+            latent_y=latent_y,
+            latent_z=latent_z,
+            n_acc=n_acc,
+            f_acc=f_acc,
         )
 
         acc_V_A1, acc_V_A2 = self.compute_accumulators_V(
-            X, y, VProd, n_acc, f_acc, latent_x, latent_y, latent_z
+            X=X,
+            y=y,
+            VProd=VProd,
+            n_acc=n_acc,
+            f_acc=f_acc,
+            latent_x=latent_x,
+            latent_y=latent_y,
+            latent_z=latent_z,
         )
 
         return acc_V_A1, acc_V_A2
@@ -1508,7 +1647,7 @@ class JFAMachine(FactorAnalysisBase):
             (self.ubm.n_gaussians * self.feature_dimension, self.r_V)
         )
 
-    def finalize_v(self, X, y, n_acc, f_acc):
+    def finalize_v(self, X, y, n_classes, n_acc, f_acc):
         """
         Compute for the last time `E[y]`
 
@@ -1521,6 +1660,9 @@ class JFAMachine(FactorAnalysisBase):
             y: list of int
                 List of labels
 
+            n_classes: int
+                Number of classes
+
             n_acc: array
                 Accumulated 0th-order statistics
 
@@ -1535,16 +1677,26 @@ class JFAMachine(FactorAnalysisBase):
         """
         VProd = self._compute_vprod()
 
-        latent_x, latent_y, latent_z = self.initialize_XYZ(y)
+        latent_x, latent_y, latent_z = self.initialize_XYZ(
+            y, n_classes=n_classes
+        )
 
         # UPDATE Y, X AND FINALLY Z
 
         latent_y = self.update_y(
-            X, y, VProd, latent_x, latent_y, latent_z, n_acc, f_acc
+            X=X,
+            y=y,
+            n_classes=n_classes,
+            VProd=VProd,
+            latent_x=latent_x,
+            latent_y=latent_y,
+            latent_z=latent_z,
+            n_acc=n_acc,
+            f_acc=f_acc,
         )
         return latent_y
 
-    def e_step_u(self, X, y, latent_y):
+    def e_step_u(self, X, y, n_classes, latent_y):
         """
         ISV E-step for the U matrix.
 
@@ -1573,12 +1725,24 @@ class JFAMachine(FactorAnalysisBase):
         """
         # self.initialize_XYZ(y)
         UProd = self._compute_uprod()
-        latent_x, _, latent_z = self.initialize_XYZ(y)
+        latent_x, _, latent_z = self.initialize_XYZ(y, n_classes)
 
-        latent_x = self.update_x(X, y, UProd, latent_x, latent_y)
+        latent_x = self.update_x(
+            X=X,
+            y=y,
+            n_classes=n_classes,
+            UProd=UProd,
+            latent_x=latent_x,
+            latent_y=latent_y,
+        )
 
         acc_U_A1, acc_U_A2 = self.compute_accumulators_U(
-            X, y, UProd, latent_x, latent_y, latent_z
+            X=X,
+            y=y,
+            UProd=UProd,
+            latent_x=latent_x,
+            latent_y=latent_y,
+            latent_z=latent_z,
         )
 
         return acc_U_A1, acc_U_A2
@@ -1605,6 +1769,7 @@ class JFAMachine(FactorAnalysisBase):
         self,
         X,
         y,
+        n_classes,
         latent_y,
     ):
         """
@@ -1619,6 +1784,9 @@ class JFAMachine(FactorAnalysisBase):
             y: list of int
                 List of labels
 
+            n_classes: int
+                Number of classes
+
             latent_y: array
                 E[y] latent variable
 
@@ -1629,15 +1797,20 @@ class JFAMachine(FactorAnalysisBase):
         """
 
         UProd = self._compute_uprod()
-        latent_x, _, _ = self.initialize_XYZ(y)
+        latent_x, _, _ = self.initialize_XYZ(y, n_classes=n_classes)
 
         latent_x = self.update_x(
-            X, y, UProd, latent_x=latent_x, latent_y=latent_y
+            X=X,
+            y=y,
+            n_classes=n_classes,
+            UProd=UProd,
+            latent_x=latent_x,
+            latent_y=latent_y,
         )
 
         return latent_x
 
-    def e_step_d(self, X, y, latent_x, latent_y, n_acc, f_acc):
+    def e_step_d(self, X, y, n_classes, latent_x, latent_y, n_acc, f_acc):
         """
         ISV E-step for the U matrix.
 
@@ -1650,6 +1823,9 @@ class JFAMachine(FactorAnalysisBase):
             y: list of int
                 List of labels
 
+            n_classes: int
+                Number of classes
+
             latent_x: array
                 E(x) latent variable
 
@@ -1677,7 +1853,7 @@ class JFAMachine(FactorAnalysisBase):
 
         """
 
-        _, _, latent_z = self.initialize_XYZ(y)
+        _, _, latent_z = self.initialize_XYZ(y, n_classes=n_classes)
 
         latent_z = self.update_z(
             X,
@@ -1712,7 +1888,7 @@ class JFAMachine(FactorAnalysisBase):
         """
         self._D = acc_D_A2 / acc_D_A1
 
-    def enroll(self, X, iterations=1):
+    def enroll_using_stats(self, X, iterations=1):
         """
         Enrolls a new client.
         In JFA the enrolment is defined as: :math:`m + Vy + Dz` with the latent variables `y` and `z`
@@ -1734,27 +1910,49 @@ class JFAMachine(FactorAnalysisBase):
         """
         # We have only one class for enrollment
         y = list(np.zeros(len(X), dtype=np.int32))
-        n_acc = self._sum_n_statistics(X, y=y)
-        f_acc = self._sum_f_statistics(X, y=y)
+        n_acc = self._sum_n_statistics(X, y=y, n_classes=1)
+        f_acc = self._sum_f_statistics(X, y=y, n_classes=1)
 
         UProd = self._compute_uprod()
         VProd = self._compute_vprod()
-        latent_x, latent_y, latent_z = self.initialize_XYZ(y)
+        latent_x, latent_y, latent_z = self.initialize_XYZ(y, n_classes=1)
 
         for i in range(iterations):
             logger.info("Enrollment: Iteration %d", i + 1)
             latent_y = self.update_y(
-                X, y, VProd, latent_x, latent_y, latent_z, n_acc, f_acc
+                X=X,
+                y=y,
+                n_classes=1,
+                VProd=VProd,
+                latent_x=latent_x,
+                latent_y=latent_y,
+                latent_z=latent_z,
+                n_acc=n_acc,
+                f_acc=f_acc,
+            )
+            latent_x = self.update_x(
+                X=X,
+                y=y,
+                n_classes=1,
+                UProd=UProd,
+                latent_x=latent_x,
+                latent_y=latent_y,
+                latent_z=latent_z,
             )
-            latent_x = self.update_x(X, y, UProd, latent_x, latent_y, latent_z)
             latent_z = self.update_z(
-                X, y, latent_x, latent_y, latent_z, n_acc, f_acc
+                X=X,
+                y=y,
+                latent_x=latent_x,
+                latent_y=latent_y,
+                latent_z=latent_z,
+                n_acc=n_acc,
+                f_acc=f_acc,
             )
 
         # The latent variables are wrapped in to 2axis arrays
         return latent_y[0], latent_z[0]
 
-    def enroll_with_array(self, X, iterations=1):
+    def enroll(self, X, iterations=1):
         """
         Enrolls a new client using a numpy array as input
 
@@ -1772,7 +1970,7 @@ class JFAMachine(FactorAnalysisBase):
             z
 
         """
-        return self.enroll([self.ubm.transform(X)], iterations)
+        return self.enroll_using_stats([self.ubm.transform(X)], iterations)
 
     def fit_using_stats(self, X, y):
         """
@@ -1801,33 +1999,55 @@ class JFAMachine(FactorAnalysisBase):
             self.create_UVD()
 
         y = np.asarray(y)
+        n_classes = self.estimate_number_of_classes(y)
 
         # TODO: Point of MAP-REDUCE
-        n_acc, f_acc = self.initialize_using_stats(X, y)
+        n_acc, f_acc = self.initialize_using_stats(X, y, n_classes=n_classes)
 
         # Updating V
         for i in range(self.em_iterations):
             logger.info("V Training: Iteration %d", i + 1)
             # TODO: Point of MAP-REDUCE
-            acc_V_A1, acc_V_A2 = self.e_step_v(X, y, n_acc, f_acc)
+            acc_V_A1, acc_V_A2 = self.e_step_v(
+                X=X,
+                y=y,
+                n_classes=n_classes,
+                n_acc=n_acc,
+                f_acc=f_acc,
+            )
             self.m_step_v(acc_V_A1, acc_V_A2)
-        latent_y = self.finalize_v(X, y, n_acc, f_acc)
+        latent_y = self.finalize_v(
+            X=X, y=y, n_classes=n_classes, n_acc=n_acc, f_acc=f_acc
+        )
 
         # Updating U
         for i in range(self.em_iterations):
             logger.info("U Training: Iteration %d", i + 1)
             # TODO: Point of MAP-REDUCE
-            acc_U_A1, acc_U_A2 = self.e_step_u(X, y, latent_y)
+            acc_U_A1, acc_U_A2 = self.e_step_u(
+                X=X,
+                y=y,
+                n_classes=n_classes,
+                latent_y=latent_y,
+            )
             self.m_step_u(acc_U_A1, acc_U_A2)
 
-        latent_x = self.finalize_u(X, y, latent_y)
+        latent_x = self.finalize_u(
+            X=X, y=y, n_classes=n_classes, latent_y=latent_y
+        )
 
         # Updating D
         for i in range(self.em_iterations):
             logger.info("D Training: Iteration %d", i + 1)
             # TODO: Point of MAP-REDUCE
             acc_D_A1, acc_D_A2 = self.e_step_d(
-                X, y, latent_x, latent_y, n_acc, f_acc
+                X=X,
+                y=y,
+                n_classes=n_classes,
+                latent_x=latent_x,
+                latent_y=latent_y,
+                n_acc=n_acc,
+                f_acc=f_acc,
             )
             self.m_step_d(acc_D_A1, acc_D_A2)
 
diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index ce8581a..394a930 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -18,11 +18,8 @@ import numpy as np
 from h5py import File as HDF5File
 from sklearn.base import BaseEstimator
 
-from .kmeans import (
-    KMeansMachine,
-    array_to_delayed_list,
-    check_and_persist_dask_input,
-)
+from .kmeans import KMeansMachine
+from .utils import array_to_delayed_list, check_and_persist_dask_input
 
 logger = logging.getLogger(__name__)
 
@@ -853,10 +850,13 @@ class GMMMachine(BaseEstimator):
             )
         return self
 
-    def transform(self, X, **kwargs):
+    def transform(self, X):
         """Returns the statistics for `X`."""
         return e_step(data=X, machine=self)
 
+    def transform_one_by_one(self, X):
+        return [e_step(data=xx, machine=self) for xx in X]
+
     def _more_tags(self):
         return {
             "stateless": False,
diff --git a/bob/learn/em/kmeans.py b/bob/learn/em/kmeans.py
index 8a17af3..61af76e 100644
--- a/bob/learn/em/kmeans.py
+++ b/bob/learn/em/kmeans.py
@@ -15,6 +15,8 @@ import scipy.spatial.distance
 from dask_ml.cluster.k_means import k_init
 from sklearn.base import BaseEstimator
 
+from .utils import array_to_delayed_list, check_and_persist_dask_input
+
 logger = logging.getLogger(__name__)
 
 
@@ -171,32 +173,6 @@ def reduce_indices_means_vars(stats):
     return variances, weights
 
 
-def check_and_persist_dask_input(data):
-    # check if input is a dask array. If so, persist and rebalance data
-    input_is_dask = False
-    if isinstance(data, da.Array):
-        data: da.Array = data.persist()
-        input_is_dask = True
-        # if there is a dask distributed client, rebalance data
-        try:
-            client = dask.distributed.Client.current()
-            client.rebalance()
-        except ValueError:
-            pass
-
-    else:
-        data = np.asarray(data)
-    return input_is_dask, data
-
-
-def array_to_delayed_list(data, input_is_dask):
-    # If input is a dask array, convert to delayed chunks
-    if input_is_dask:
-        data = data.to_delayed().ravel().tolist()
-        logger.debug(f"Got {len(data)} chunks.")
-    return data
-
-
 class KMeansMachine(BaseEstimator):
     """Stores the k-means clusters parameters (centroid of each cluster).
 
diff --git a/bob/learn/em/test/test_factor_analysis.py b/bob/learn/em/test/test_factor_analysis.py
index 3b9ca21..08ae5ab 100644
--- a/bob/learn/em/test/test_factor_analysis.py
+++ b/bob/learn/em/test/test_factor_analysis.py
@@ -194,7 +194,7 @@ def test_JFATrainAndEnrol():
     gse2.sum_px = Fe[:, 1].reshape(2, 3)
 
     gse = [gse1, gse2]
-    latent_y, latent_z = it.enroll(gse, 5)
+    latent_y, latent_z = it.enroll_using_stats(gse, 5)
 
     y_ref = np.array([0.555991469319657, 0.002773650670010], "float64")
     z_ref = np.array(
@@ -302,7 +302,7 @@ def test_ISVTrainAndEnrol():
     gse2.sum_px = Fe[:, 1].reshape(2, 3)
 
     gse = [gse1, gse2]
-    latent_z = it.enroll(gse, 5)
+    latent_z = it.enroll_using_stats(gse, 5)
     np.testing.assert_allclose(latent_z, z_ref, rtol=eps, atol=1e-8)
 
 
@@ -320,13 +320,14 @@ def test_JFATrainInitialize():
     it = JFAMachine(2, 2, em_iterations=10, ubm=ubm)
     # first round
 
-    it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
+    n_classes = it.estimate_number_of_classes(TRAINING_STATS_y)
+    it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y, n_classes)
     u1 = it.U
     v1 = it.V
     d1 = it.D
 
     # second round
-    it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
+    it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y, n_classes)
     u2 = it.U
     v2 = it.V
     d2 = it.D
@@ -350,12 +351,13 @@ def test_ISVTrainInitialize():
     it = ISVMachine(2, em_iterations=10, ubm=ubm)
     # it.rng = rng
 
-    it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
+    n_classes = it.estimate_number_of_classes(TRAINING_STATS_y)
+    it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y, n_classes)
     u1 = copy.deepcopy(it.U)
     d1 = copy.deepcopy(it.D)
 
     # second round
-    it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y)
+    it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y, n_classes)
     u2 = it.U
     d2 = it.D
 
diff --git a/bob/learn/em/utils.py b/bob/learn/em/utils.py
new file mode 100644
index 0000000..108e797
--- /dev/null
+++ b/bob/learn/em/utils.py
@@ -0,0 +1,33 @@
+import logging
+
+import dask
+import dask.array as da
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+
+def check_and_persist_dask_input(data):
+    """Persist a dask array, or convert any other input to a numpy array.
+
+    Returns the tuple ``(input_is_dask, data)``.
+    """
+    if not isinstance(data, da.Array):
+        return False, np.asarray(data)
+    data = data.persist()
+    # Rebalance only if a distributed client exists (ValueError otherwise).
+    # NOTE(review): this module imports only ``dask``; confirm that
+    # ``dask.distributed`` is always loaded before this function runs.
+    try:
+        dask.distributed.Client.current().rebalance()
+    except ValueError:
+        pass
+    return True, data
+
+
+def array_to_delayed_list(data, input_is_dask):
+    """Flatten ``data`` into a list of delayed chunks if ``input_is_dask``."""
+    if input_is_dask:
+        data = data.to_delayed().ravel().tolist()
+        logger.debug(f"Got {len(data)} chunks.")
+    return data
-- 
GitLab