diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index f3560749635e9bf69b33104a275764f7711c9451..8a76d29d084e68abc9d3cc1dbe4f9bcb5eaecbb0 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -4,10 +4,12 @@
 
 import logging
 
+import dask
 import numpy as np
 
 from sklearn.base import BaseEstimator
 
+from .gmm import GMMMachine
 from .linear_scoring import linear_scoring
 
 logger = logging.getLogger(__name__)
@@ -57,9 +59,6 @@ class FactorAnalysisBase(BaseEstimator):
     Parameters
     ----------
 
-    ubm: :py:class:`bob.learn.em.GMMMachine`
-        A trained UBM (Universal Background Model)
-
     r_U: int
         Dimension of the subspace U
 
@@ -75,17 +74,22 @@ class FactorAnalysisBase(BaseEstimator):
 
     seed: int
         Seed for the random number generator
-
+    ubm: :py:class:`bob.learn.em.GMMMachine`
+        A trained UBM (Universal Background Model) or a parametrized
+        :py:class:`bob.learn.em.GMMMachine` to train the UBM with. If None,
+        `gmm_kwargs` are passed as parameters of a new
+        :py:class:`bob.learn.em.GMMMachine`.
     """
 
     def __init__(
         self,
-        ubm,
         r_U,
         r_V=None,
         relevance_factor=4.0,
        em_iterations=10,
         seed=0,
+        ubm=None,
+        **gmm_kwargs,
     ):
         self.ubm = ubm
         self.em_iterations = em_iterations
@@ -96,8 +100,6 @@ class FactorAnalysisBase(BaseEstimator):
         self.r_V = r_V
         self.relevance_factor = relevance_factor
 
-        # Initializing the state matrix
-        self.create_UVD()
 
     @property
     def feature_dimension(self):
@@ -176,7 +178,7 @@ class FactorAnalysisBase(BaseEstimator):
 
     def initialize(self, X, y):
         """
-        Accumulating 0th and 1st order statistics
+        Accumulating 0th and 1st order statistics. Trains the UBM if needed.
 
         Parameters
         ----------
@@ -195,13 +197,29 @@ class FactorAnalysisBase(BaseEstimator):
         """
 
+        if self.ubm is None:
+            logger.info("FA: Creating a new GMMMachine.")
+            self.ubm = GMMMachine(**self.gmm_kwargs)
+
+        # Train the UBM if not already trained
+        if self.ubm._means is None:
+            logger.info(f"FA: Training the UBM with {self.ubm}.")
+            self.ubm.fit(np.vstack(X))  # GMMMachine.fit takes non-labeled data
+
+        logger.info("FA: Projection of training data on the UBM.")
+        ubm_projected_X = [dask.delayed(self.ubm.transform(xx)) for xx in X]
+
+        # Initializing the state matrix
+        if not hasattr(self, "_U") or not hasattr(self, "_D"):
+            self.create_UVD()
+
         # Accumulating 0th and 1st order statistics
         # https://gitlab.idiap.ch/bob/bob.learn.em/-/blob/da92d0e5799d018f311f1bf5cdd5a80e19e142ca/bob/learn/em/cpp/ISVTrainer.cpp#L68
 
         # 0th order stats
-        n_acc = self._sum_n_statistics(X, y)
+        n_acc = self._sum_n_statistics(ubm_projected_X, y)
 
         # 1st order stats
-        f_acc = self._sum_f_statistics(X, y)
+        f_acc = self._sum_f_statistics(ubm_projected_X, y)
 
         return n_acc, f_acc
 
@@ -1143,9 +1161,6 @@ class ISVMachine(FactorAnalysisBase):
     Parameters
     ----------
 
-    ubm: :py:class:`bob.learn.em.GMMMachine`
-        A trained UBM (Universal Background Model)
-
     r_U: int
         Dimension of the subspace U
 
@@ -1158,20 +1173,43 @@ class ISVMachine(FactorAnalysisBase):
     seed: int
         Seed for the random number generator
 
+    ubm: :py:class:`bob.learn.em.GMMMachine` or None
+        A trained UBM (Universal Background Model). If None, the UBM is trained with
+        a new :py:class:`bob.learn.em.GMMMachine` when fit is called, with `gmm_kwargs`
+        as parameters.
+
     """
 
     def __init__(
-        self, ubm, r_U, em_iterations=10, relevance_factor=4.0, seed=0
+        self,
+        r_U,
+        em_iterations=10,
+        relevance_factor=4.0,
+        seed=0,
+        ubm=None,
+        **gmm_kwargs,
     ):
         super(ISVMachine, self).__init__(
-            ubm,
             r_U=r_U,
             relevance_factor=relevance_factor,
            em_iterations=em_iterations,
             seed=seed,
+            ubm=ubm,
+            **gmm_kwargs,
         )
 
     def initialize(self, X, y):
+        """Initializes the ISV parameters and trains a UBM with `X` if needed.
+
+        If no UBM has been defined on init, it is trained with a new GMMMachine.
+
+        Parameters
+        ----------
+        X: np.ndarray of shape(n_clients, n_samples, n_features)
+            Input data for each client.
+        y: np.ndarray of shape(n_clients,)
+            Client labels.
+        """
         return super(ISVMachine, self).initialize(X, y)
 
     def e_step(self, X, y, n_acc, f_acc):
@@ -1229,11 +1267,7 @@ class ISVMachine(FactorAnalysisBase):
 
         """
 
-        # In case those variables are already set
-        if not hasattr(self, "_U") or not hasattr(self, "_D"):
-            self.create_UVD()
-
-        y = y.tolist() if not isinstance(y, list) else y
+        y = np.array(y).tolist() if not isinstance(y, list) else y
 
         # TODO: Point of MAP-REDUCE
         n_acc, f_acc = self.initialize(X, y)
@@ -1245,6 +1279,10 @@ class ISVMachine(FactorAnalysisBase):
 
         return self
 
+    def transform(self, X):
+        ubm_projected_X = self.ubm.transform(X)
+        return self.estimate_ux(ubm_projected_X)
+
     def enroll(self, X, iterations=1):
         """
         Enrolls a new client
@@ -1754,7 +1792,7 @@ class JFAMachine(FactorAnalysisBase):
         ):
             self.create_UVD()
 
-        y = y.tolist() if not isinstance(y, list) else y
+        y = np.array(y).tolist() if not isinstance(y, list) else y
 
         # TODO: Point of MAP-REDUCE
         n_acc, f_acc = self.initialize(X, y)