From 88fe8201f726295f6ae9d0127e93532ba53b2a64 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Mon, 28 Mar 2022 18:30:52 +0200 Subject: [PATCH] Implemented JFA --- bob/learn/em/__init__.py | 4 +- bob/learn/em/factor_analysis.py | 1460 +++++++++++++++++++++---- bob/learn/em/test/test_jfa.py | 100 ++ bob/learn/em/test/test_jfa_trainer.py | 157 +-- 4 files changed, 1445 insertions(+), 276 deletions(-) create mode 100644 bob/learn/em/test/test_jfa.py diff --git a/bob/learn/em/__init__.py b/bob/learn/em/__init__.py index eb5af32..669a8a3 100644 --- a/bob/learn/em/__init__.py +++ b/bob/learn/em/__init__.py @@ -5,7 +5,7 @@ from .kmeans import KMeansMachine from .linear_scoring import linear_scoring # noqa: F401 from .wccn import WCCN from .whitening import Whitening -from .factor_analysis import ISVMachine +from .factor_analysis import ISVMachine, JFAMachine def get_config(): @@ -30,6 +30,6 @@ def __appropriate__(*args): __appropriate__( - KMeansMachine, GMMMachine, GMMStats, WCCN, Whitening, ISVMachine + KMeansMachine, GMMMachine, GMMStats, WCCN, Whitening, ISVMachine, JFAMachine ) __all__ = [_ for _ in dir() if not _.startswith("_")] diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py index 01a1545..06d868c 100644 --- a/bob/learn/em/factor_analysis.py +++ b/bob/learn/em/factor_analysis.py @@ -4,14 +4,12 @@ import logging - import numpy as np -import scipy.spatial.distance from sklearn.base import BaseEstimator +from . import linear_scoring logger = logging.getLogger(__name__) -import bob.core def mult_along_axis(A, B, axis): @@ -50,12 +48,47 @@ def mult_along_axis(A, B, axis): class FactorAnalysisBase(BaseEstimator): """ - GMM Factor Analysis base class + Factor Analysis base class. + This class is not intended to be used directly, but rather to be inherited from. + For more information check [McCool2013]_ . + + + Parameters + ---------- + + ubm: :py:class:`bob.learn.em.GMMMachine` + A trained UBM (Universal Background Model) + + r_U: int + Dimension of the subspace U + + r_V: int + Dimension of the subspace V + + em_iterations: int + Number of EM iterations + + relevance_factor: float + Factor analysis relevance factor + + seed: int + Seed for the random number generator + """ - def __init__(self, ubm, r_U, r_V=None, relevance_factor=4.0): + def __init__( + self, + ubm, + r_U, + r_V=None, + relevance_factor=4.0, + em_iterations=10, + seed=0, + ): self.ubm = ubm + self.em_iterations = em_iterations + self.seed = seed # axis 1 dimensions of U and V self.r_U = r_U @@ -91,6 +124,48 @@ class FactorAnalysisBase(BaseEstimator): """ return self.ubm.variances.flatten() + @property + def U(self): + """An alias for `_U`.""" + return self._U + + @U.setter + def U(self, value): + U_shape = (self.supervector_dimension, self.r_U) + if value.shape != U_shape: + raise ValueError( + f"U must be a numpy array of shape {U_shape}, but a matrix of shape {value.shape} was provided." + ) + self._U = value + + @property + def D(self): + """An alias for `_D`.""" + return self._D + + @D.setter + def D(self, value): + D_shape = (self.supervector_dimension,) + if value.shape != D_shape: + raise ValueError( + f"D must be a numpy array of shape {D_shape}, but a matrix of shape {value.shape} was provided." 
+ ) + self._D = value + + @property + def V(self): + """An alias for `_V`.""" + return self._V + + @V.setter + def V(self, value): + V_shape = (self.supervector_dimension, self.r_V) + if value.shape != V_shape: + raise ValueError( + f"V must be a numpy array of shape {V_shape}, but a matrix of shape {value.shape} was provided." + ) + self._V = value + def estimate_number_of_classes(self, y): """ Estimates the number of classes given the labels @@ -133,58 +208,52 @@ class FactorAnalysisBase(BaseEstimator): """ Create the state matrices U, V and D - U: (n_gaussians*feature_dimension, r_U) represents the session variability matrix (within-class variability) + Returns + ------- + + U: (n_gaussians*feature_dimension, r_U) represents the session variability matrix (within-class variability) - V: (n_gaussians*feature_dimension, r_V) represents the session variability matrix (between-class variability) + V: (n_gaussians*feature_dimension, r_V) represents the session variability matrix (between-class variability) - D: (n_gaussians*feature_dimension) represents the client offset vector + D: (n_gaussians*feature_dimension) represents the client offset vector """ + if self.seed is not None: + np.random.seed(self.seed) U_shape = (self.supervector_dimension, self.r_U) # U matrix is initialized using a normal distribution - # https://gitlab.idiap.ch/bob/bob.learn.em/-/blob/da92d0e5799d018f311f1bf5cdd5a80e19e142ca/bob/learn/em/cpp/ISVTrainer.cpp#L72 - # TODO: Temporary workaround, so I can reuse the test cases - if isinstance(self.seed, bob.core.random.mt19937): - self.U = bob.core.random.variate_generator( - bob.core.random.mt19937(0), - bob.core.random.normal("float64", mean=0, sigma=1), - )(shape=U_shape) - else: - # Assuming that the seed is an integer - self.U = np.random.normal(scale=1.0, loc=0.0, size=U_shape) + self._U = np.random.normal(scale=1.0, loc=0.0, size=U_shape) # D matrix is initialized as `D = sqrt(variance(UBM) / relevance_factor)` - self.D = np.sqrt(self.variance_supervector / self.relevance_factor) + self._D = np.sqrt(self.variance_supervector / self.relevance_factor) # V matrix (or between-class variation matrix) # TODO: so far not doing JFA - self.V = None - - def _get_statistics_by_class_id(self, X, y, i): - """ - Returns the statistics for a given class - """ - X = np.array(X) - return list(X[np.where(np.array(y) == i)[0]]) - - #################### Estimating U and x ###################### + if self.r_V is not None: + V_shape = (self.supervector_dimension, self.r_V) + self._V = np.random.normal(scale=1.0, loc=0.0, size=V_shape) + else: + self._V = 0 - def _computeUVD(self): + def _sum_n_statistics(self, X, y): """ - Precomputing `U.T @ inv(Sigma)`. 
- See Eq 37 + Accumulates the 0th statistics for each client - TODO: I have to see if worth to keeping this cache + Parameters + ---------- + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics of each sample - """ + y: list of ints + List of corresponding labels - return self.U.T / self.variance_supervector + Returns + ------- + n_acc: array + (n_classes, n_gaussians) representing the accumulated 0th order statistics - def _sum_n_statistics(self, X, y): - """ - Accumulates the 0th statistics for each client """ # 0th order stats n_acc = np.zeros( @@ -201,6 +270,19 @@ class FactorAnalysisBase(BaseEstimator): def _sum_f_statistics(self, X, y): """ Accumulates the 1st order statistics for each client + + Parameters + ---------- + X: list of :py:class:`bob.learn.em.GMMStats` + + y: list of ints + List of corresponding labels + + Returns + ------- + f_acc: array + (n_classes, n_gaussians, feature_dimension) representing the accumulated 1st order statistics + """ # 1st order stats @@ -218,9 +300,43 @@ class FactorAnalysisBase(BaseEstimator): return f_acc - def _compute_id_plus_prod_ih(self, x_i, y_i, UProd): + def _get_statistics_by_class_id(self, X, y, i): + """ + Returns the statistics for a given class + + Parameters + ---------- + X: list of :py:class:`bob.learn.em.GMMStats` + + y: list of ints + List of corresponding labels + + i: int + Class id to return the statistics for + """ + X = np.array(X) + return list(X[np.where(np.array(y) == i)[0]]) + + #################### Estimating U and x ###################### + + def _compute_id_plus_u_prod_ih(self, x_i, UProd): """ - Computes ( I+Ut*diag(sigma)^-1*Ni*U)^-1) + Computes ( I+Ut*diag(sigma)^-1*Ni*U)^-1 + See equation (29) in [McCool2013]_ + + Parameters + ---------- + x_i: :py:class:`bob.learn.em.GMMStats` + Statistics of a single sample + + UProd: array + Matrix containing U_c.T*inv(Sigma_c) @ U_c.T + + Returns + ------- + id_plus_u_prod_ih: array + ( I+Ut*diag(sigma)^-1*Ni*U)^-1 + """ n_i = x_i.n @@ -229,59 +345,138 @@ class FactorAnalysisBase(BaseEstimator): # TODO: make the invertion matrix function as a parameter return np.linalg.inv(I + (UProd * n_i[:, None, None]).sum(axis=0)) - def _computefn_x_ih(self, x_i, y_i, latent_z=None): + def _computefn_x_ih(self, x_i, latent_z_i=None, latent_y_i=None): """ - Fn_x_ih = N_{i,h}*(o_{i,h} - m - D*z_{i}) + Computes Fn_x_ih = N_{i,h}*(o_{i,h} - m - D*z_{i} - V*y_{i}) + Check equation (29) in [McCool2013]_ + + Parameters + ---------- + x_i: :py:class:`bob.learn.em.GMMStats` + Statistics of a single sample + + latent_z_i: array + E[z_i] for class `i` + + latent_y_i: array + E[y_i] for class `i` + """ f_i = x_i.sum_px n_i = x_i.n n_ic = np.repeat(n_i, self.supervector_dimension // 2) + V = self._V ## N_ih*( m + D*z) # z is zero when the computation flow comes from update_X - if latent_z is None: + if latent_z_i is None: # Fn_x_ih = N_{i,h}*(o_{i,h} - m) fn_x_ih = f_i.flatten() - n_ic * (self.mean_supervector) else: # code goes here when the computation flow comes from compute_acculators # Fn_x_ih = N_{i,h}*(o_{i,h} - m - D*z_{i}) fn_x_ih = f_i.flatten() - n_ic * ( - self.mean_supervector + self.D * latent_z[y_i] + self.mean_supervector + self._D * latent_z_i ) """ - # JFA Part (eq 33) - const blitz::Array<double, 1> &y = m_y[id]; - std::cout << "V" << y << std::endl; - bob::math::prod(V, y, m_tmp_CD_b); - m_cache_Fn_x_ih -= m_tmp_CD * m_tmp_CD_b; + # JFA Part (eq 29) """ + V_dot_v = V @ latent_y_i if latent_y_i is not None else 0 + fn_x_ih -= n_ic * V_dot_v if latent_y_i is not None 
else 0 return fn_x_ih - def update_x(self, X, y, UProd, latent_x, latent_z=None): + def update_x(self, X, y, UProd, latent_x, latent_y=None, latent_z=None): """ - Computes the accumulators U_a1, U_a2 for U - U = A2 * A1^-1 + Computes a new math:`E[x]` See equation (29) in [McCool2013]_ + + + Parameters + ---------- + + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of ints + List of corresponding labels + + UProd: array + Matrix containing U_c.T*inv(Sigma_c) @ U_c.T + + latent_x: array + E(x) latent variable + + latent_y: array + E(y) latent variable + + latent_z: array + E(z) latent variable + + Returns + ------- + Returns the new latent_x + """ - # U.T @ inv(Sigma) - UTinvSigma = self._computeUVD() - # UProd = self.compute_uprod() + # U.T @ inv(Sigma) - See Eq(37) + UTinvSigma = self._U.T / self.variance_supervector session_offsets = np.zeros(self.estimate_number_of_classes(y)) # For each sample for x_i, y_i in zip(X, y): - id_plus_prod_ih = self._compute_id_plus_prod_ih(x_i, y_i, UProd) - fn_x_ih = self._computefn_x_ih(x_i, y_i, latent_z) + id_plus_prod_ih = self._compute_id_plus_u_prod_ih(x_i, UProd) + latent_z_i = latent_z[y_i] if latent_z is not None else None + latent_y_i = latent_y[y_i] if latent_y is not None else None + + fn_x_ih = self._computefn_x_ih( + x_i, latent_z_i=latent_z_i, latent_y_i=latent_y_i + ) latent_x[y_i][:, int(session_offsets[y_i])] = id_plus_prod_ih @ ( UTinvSigma @ fn_x_ih ) session_offsets[y_i] += 1 return latent_x - def compute_uprod(self): + def update_U(self, acc_U_A1, acc_U_A2): + """ + Update rule for U + + Parameters + ---------- + + acc_U_A1: array + Accumulated statistics for U_A1(n_gaussians, r_U, r_U) + + acc_U_A2: array + Accumulated statistics for U_A2(n_gaussians* feature_dimention, r_U) + + """ + + # Inverting A1 over the zero axis + # https://stackoverflow.com/questions/11972102/is-there-a-way-to-efficiently-invert-an-array-of-matrices-with-numpy + inv_A1 = np.linalg.inv(acc_U_A1) + + # Iterating over the gaussians to update U + + for c in range(self.ubm.n_gaussians): + + U_c = ( + acc_U_A2[ + c + * self.feature_dimension : (c + 1) + * self.feature_dimension, + :, + ] + @ inv_A1[c, :, :] + ) + self._U[ + c * self.feature_dimension : (c + 1) * self.feature_dimension, + :, + ] = U_c + + def _compute_uprod(self): """ Computes U_c.T*inv(Sigma_c) @ U_c.T @@ -291,7 +486,7 @@ class FactorAnalysisBase(BaseEstimator): UProd = np.zeros((self.ubm.n_gaussians, self.r_U, self.r_U)) for c in range(self.ubm.n_gaussians): # U_c.T - U_c = self.U[ + U_c = self._U[ c * self.feature_dimension : (c + 1) * self.feature_dimension, : ] sigma_c = self.ubm.variances[c].flatten() @@ -299,7 +494,52 @@ class FactorAnalysisBase(BaseEstimator): return UProd - def compute_accumulators_U(self, X, y, UProd, latent_x, latent_z): + def compute_accumulators_U(self, X, y, UProd, latent_x, latent_y, latent_z): + """ + Computes the accumulators (A1 and A2) for the U matrix. + This is useful for parallelization purposes. 
+ + The accumulators are defined as + + :math:`A_1 = \sum\limits_{i=1}^{I}\sum\limits_{h=1}^{H}N_{i,h,c}E(x_{i,h,c} x^{\top}_{i,h,c})` + + + :math:`A_2 = \sum\limits_{i=1}^{I}\sum\limits_{h=1}^{H}N_{i,h,c}(o_{i,h} - \mu_c -D_{c}z_{i,c} -V_{c}y_{i,c} )E[x_{i,h}]^{\top}` + + + More information, please, check the technical notes attached + + Parameters + ---------- + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of ints + List of corresponding labels + + UProd: array + Matrix containing U_c.T*inv(Sigma_c) @ U_c.T + + latent_x: array + E(x) latent variable + + latent_y: array + E(y) latent variable + + latent_z: array + E(z) latent variable + + Returns + ------- + acc_U_A1: + (n_gaussians, r_U, r_U) A1 accumulator + + acc_U_A2: + (n_gaussians* feature_dimention, r_U) A2 accumulator + + + """ + ## U accumulators acc_U_A1 = np.zeros((self.ubm.n_gaussians, self.r_U, self.r_U)) acc_U_A2 = np.zeros((self.supervector_dimension, self.r_U)) @@ -310,8 +550,12 @@ class FactorAnalysisBase(BaseEstimator): for session_index, x_i in enumerate( self._get_statistics_by_class_id(X, y, y_i) ): - id_plus_prod_ih = self._compute_id_plus_prod_ih(x_i, y_i, UProd) - fn_x_ih = self._computefn_x_ih(x_i, y_i, latent_z) + id_plus_prod_ih = self._compute_id_plus_u_prod_ih(x_i, UProd) + latent_z_i = latent_z[y_i] if latent_z is not None else None + latent_y_i = latent_y[y_i] if latent_y is not None else None + fn_x_ih = self._computefn_x_ih( + x_i, latent_y_i=latent_y_i, latent_z_i=latent_z_i + ) latent_x_i = latent_x[y_i][:, session_index] id_plus_prod_ih += ( @@ -332,41 +576,68 @@ class FactorAnalysisBase(BaseEstimator): #################### Estimating D and z ###################### - def update_z(self, X, y, latent_x, latent_z, n_acc, f_acc): + def update_z(self, X, y, latent_x, latent_y, latent_z, n_acc, f_acc): """ - Equation 38 + Computes a new math:`E[z]` See equation (30) in [McCool2013]_ + + Parameters + ---------- + + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of ints + List of corresponding labels + + latent_x: array + E(x) latent variable + + latent_y: array + E(y) latent variable + + latent_z: array + E(z) latent variable + + n_acc: array + Accumulated 0th order statistics for each class (math:`N_{i}`) + + f_acc: array + Accumulated 1st order statistics for each class (math:`F_{i}`) + + Returns + ------- + Returns the new latent_z + """ # Precomputing - # self.D.T / sigma - dt_inv_sigma = self.D / self.variance_supervector - # self.D.T / sigma * self.D - dt_inv_sigma_d = dt_inv_sigma * self.D + # self._D.T / sigma + dt_inv_sigma = self._D / self.variance_supervector + # self._D.T / sigma * self._D + dt_inv_sigma_d = dt_inv_sigma * self._D # for each class for y_i in set(y): - id_plus_d_prod = self.computeIdPlusDProd_i( + id_plus_d_prod = self._compute_id_plus_d_prod_i( dt_inv_sigma_d, n_acc[y_i] ) - # X_i = X[y == y_i] # Getting the statistics of the current class X_i = self._get_statistics_by_class_id(X, y, y_i) - fn_z_i = self.compute_fn_z_i( - X_i, y_i, latent_x, n_acc[y_i], f_acc[y_i] + latent_x_i = latent_x[y_i] + + latent_y_i = latent_y[y_i] if latent_y is not None else None + + fn_z_i = self._compute_fn_z_i( + X_i, latent_x_i, latent_y_i, n_acc[y_i], f_acc[y_i] ) - latent_z[y_i] = self.updateZ_i(id_plus_d_prod, dt_inv_sigma, fn_z_i) + latent_z[y_i] = id_plus_d_prod * dt_inv_sigma * fn_z_i return latent_z - def updateZ_i(self, id_plus_d_prod, dt_inv_sigma, fn_z_i): - """ - // Computes zi = Azi * D^T.Sigma^-1 * Fn_zi - """ - return 
id_plus_d_prod * dt_inv_sigma * fn_z_i - - def computeIdPlusDProd_i(self, dt_inv_sigma_d, n_acc_i): + def _compute_id_plus_d_prod_i(self, dt_inv_sigma_d, n_acc_i): """ Computes: (I+Dt*diag(sigma)^-1*Ni*D)^-1 + See equation (31) in [McCool2013]_ Parameters ---------- @@ -383,7 +654,7 @@ class FactorAnalysisBase(BaseEstimator): id_plus_d_prod = np.ones(tmp_CD.shape) + dt_inv_sigma_d * tmp_CD return 1 / id_plus_d_prod - def compute_fn_z_i(self, X_i, y_i, latent_x, n_acc_i, f_acc_i): + def _compute_fn_z_i(self, X_i, latent_x_i, latent_y_i, n_acc_i, f_acc_i): """ Compute Fn_z_i = sum_{sessions h}(N_{i,h}*(o_{i,h} - m - V*y_{i} - U*x_{i,h}) (Normalised first order statistics) @@ -394,33 +665,112 @@ class FactorAnalysisBase(BaseEstimator): """ - U = self.U - V = self.V # Not doing the JFA + U = self._U + V = self._V - latent_X_i = latent_x[y_i] m = self.mean_supervector - # y = self.y[i] # Not doing JFA - tmp_CD = np.repeat(n_acc_i, self.supervector_dimension // 2) - ### NOT DOING JFA - # bob::math::prod(V, y, m_tmp_CD_b); // m_tmp_CD_b = V * y - # V_dot_v = V@v - V_dot_v = 0 # Not doing JFA + ## JFA session part + V_dot_v = V @ latent_y_i if latent_y_i is not None else 0 + # m_cache_Fn_z_i = Fi - m_tmp_CD * (m + m_tmp_CD_b); // Fn_yi = sum_{sessions h}(N_{i,h}*(o_{i,h} - m - V*y_{i}) - fn_z_i = f_acc_i.flatten() - tmp_CD * (m - V_dot_v) + fn_z_i = f_acc_i.flatten() - tmp_CD * (m + V_dot_v) # Looping over the sessions for session_id in range(len(X_i)): n_i = X_i[session_id].n tmp_CD = np.repeat(n_i, self.supervector_dimension // 2) - x_i_h = latent_X_i[:, session_id] + x_i_h = latent_x_i[:, session_id] fn_z_i -= tmp_CD * (U @ x_i_h) return fn_z_i + def compute_accumulators_D( + self, X, y, latent_x, latent_y, latent_z, n_acc, f_acc + ): + """ + Compute the acumulators for the D matrix + + The accumulators are defined as + + :math:`A_1 = \sum\limits_{i=1}^{I}E[z_{i,c}z^{\top}_{i,c}]` + + + :math:`A_2 = \sum\limits_{i=1}^{I} \Bigg[\sum\limits_{h=1}^{H}N_{i,h,c}(o_{i,h} - \mu_c -U_{c}x_{i,h,c} -V_{c}y_{i,c} )\Bigg]E[z_{i}]^{\top}` + + + More information, please, check the technical notes attached + + + Parameters + ---------- + + X: array + Input data + + y: array + Class labels + + latent_z: array + E(z) latent variable + + latent_x: array + E(x) latent variable + + latent_y: array + E(y) latent variable + + n_acc: array + Accumulated 0th order statistics for each class (math:`N_{i}`) + + f_acc: array + Accumulated 1st order statistics for each class (math:`F_{i}`) + + Returns + ------- + acc_D_A1: + (n_gaussians* feature_dimention) A1 accumulator + + acc_D_A2: + (n_gaussians* feature_dimention) A2 accumulator + + """ + + acc_D_A1 = np.zeros((self.supervector_dimension,)) + acc_D_A2 = np.zeros((self.supervector_dimension,)) + + # Precomputing + # self._D.T / sigma + dt_inv_sigma = self._D / self.variance_supervector + # self._D.T / sigma * self._D + dt_inv_sigma_d = dt_inv_sigma * self._D + + # Loops over all people + for y_i in set(y): + + id_plus_d_prod = self._compute_id_plus_d_prod_i( + dt_inv_sigma_d, n_acc[y_i] + ) + X_i = self._get_statistics_by_class_id(X, y, y_i) + latent_x_i = latent_x[y_i] + + latent_y_i = latent_y[y_i] if latent_y is not None else None + + fn_z_i = self._compute_fn_z_i( + X_i, latent_x_i, latent_y_i, n_acc[y_i], f_acc[y_i] + ) + + tmp_CD = np.repeat(n_acc[y_i], self.supervector_dimension // 2) + acc_D_A1 += ( + id_plus_d_prod + latent_z[y_i] * latent_z[y_i] + ) * tmp_CD + acc_D_A2 += fn_z_i * latent_z[y_i] + + return acc_D_A1, acc_D_A2 + def initialize_XYZ(self, y): """ 
Initialize E[x], E[y], E[z] state variables @@ -450,7 +800,11 @@ class FactorAnalysisBase(BaseEstimator): ) ) - latent_y = None + latent_y = ( + np.zeros((self.estimate_number_of_classes(y), self.r_V)) + if self.r_V and self.r_V > 0 + else None + ) latent_z = np.zeros( (self.estimate_number_of_classes(y), self.supervector_dimension) @@ -458,74 +812,384 @@ class FactorAnalysisBase(BaseEstimator): return latent_x, latent_y, latent_z - # latent_x, latent_y, latent_z = self.initialize_XYZ(y) + #################### Estimating V and y ###################### + + def update_y(self, X, y, VProd, latent_x, latent_y, latent_z, n_acc, f_acc): + """ + Computes a new math:`E[y]` See equation (30) in [McCool2013]_ + Parameters + ---------- -class ISVMachine(FactorAnalysisBase): - """ - Implements the Interssion Varibility Modelling hypothesis on top of GMMs + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics - Inter-Session Variability (ISV) modeling is a session variability modeling technique built on top of the Gaussian mixture modeling approach. - It hypothesizes that within-class variations are embedded in a linear subspace in the GMM means subspace and these variations can be suppressed - by an offset w.r.t each mean during the MAP adaptation. + y: list of ints + List of corresponding labels - """ + VProd: array + Matrix representing V_c.T*inv(Sigma_c) @ V_c.T - def __init__(self, ubm, r_U, em_iterations, relevance_factor=4.0, seed=0): - self.r_U = r_U - self.seed = seed - self.em_iterations = em_iterations - super(ISVMachine, self).__init__( - ubm, r_U=r_U, relevance_factor=relevance_factor - ) + latent_x: array + E(x) latent variable - def initialize(self, X, y): - return super(ISVMachine, self).initialize(X, y) + latent_y: array + E(y) latent variable - def e_step(self, X, y, n_acc, f_acc): - """ - E-step of the EM algorithm - """ - # self.initialize_XYZ(y) - UProd = self.compute_uprod() - latent_x, latent_y, latent_z = self.initialize_XYZ(y) + latent_z: array + E(z) latent variable - latent_x = self.update_x(X, y, UProd, latent_x) - latent_z = self.update_z(X, y, latent_x, latent_z, n_acc, f_acc) - acc_U_A1, acc_U_A2 = self.compute_accumulators_U( - X, y, UProd, latent_x, latent_z - ) + n_acc: array + Accumulated 0th order statistics for each class (math:`N_{i}`) - return acc_U_A1, acc_U_A2 + f_acc: array + Accumulated 1st order statistics for each class (math:`F_{i}`) - def m_step(self, acc_U_A1, acc_U_A2): - """ - M-step of the EM algorithm """ + # V.T / sigma + VTinvSigma = self._V.T / self.variance_supervector - # Inverting A1 over the zero axis - # https://stackoverflow.com/questions/11972102/is-there-a-way-to-efficiently-invert-an-array-of-matrices-with-numpy - inv_A1 = np.linalg.inv(acc_U_A1) + # Loops over the labels + for label in range(self.estimate_number_of_classes(y)): + id_plus_v_prod_i = self._compute_id_plus_vprod_i( + n_acc[label], VProd + ) + X_i = self._get_statistics_by_class_id(X, y, label) + fn_y_i = self._compute_fn_y_i( + X_i, + latent_x[label], + latent_z[label], + n_acc[label], + f_acc[label], + ) + latent_y[label] = (VTinvSigma @ fn_y_i) @ id_plus_v_prod_i + return latent_y - # Iterating over the gaussians to update U + def _compute_id_plus_vprod_i(self, n_acc_i, VProd): + """ + Computes: (I+Vt*diag(sigma)^-1*Ni*V)^-1 (see Eq. 
(30) in [McCool2013]_) + + Parameters + ---------- + + n_acc_i: array + Accumulated 0th order statistics for each class (math:`N_{i}`) + + VProd: array + Matrix representing V_c.T*inv(Sigma_c) @ V_c.T + + """ + I = np.eye(self.r_V, self.r_V) + + # TODO: make the invertion matrix function as a parameter + return np.linalg.inv(I + (VProd * n_acc_i[:, None, None]).sum(axis=0)) + + def _compute_vprod(self): + """ + Computes V_c.T*inv(Sigma_c) @ V_c.T + + + ### https://gitlab.idiap.ch/bob/bob.learn.em/-/blob/da92d0e5799d018f311f1bf5cdd5a80e19e142ca/bob/learn/em/cpp/FABaseTrainer.cpp#L193 + """ + + VProd = np.zeros((self.ubm.n_gaussians, self.r_V, self.r_V)) for c in range(self.ubm.n_gaussians): + # V_c.T + V_c = self._V[ + c * self.feature_dimension : (c + 1) * self.feature_dimension, : + ] + sigma_c = self.ubm.variances[c].flatten() + VProd[c, :, :] = V_c.T @ (V_c.T / sigma_c).T - U_c = ( - acc_U_A2[ - c - * self.feature_dimension : (c + 1) - * self.feature_dimension, - :, - ] - @ inv_A1[c, :, :] + return VProd + + def compute_accumulators_V( + self, X, y, VProd, n_acc, f_acc, latent_x, latent_y, latent_z + ): + """ + Computes the accumulators for the update of V matrix + The accumulators are defined as + + :math:`A_1 = \sum\limits_{i=1}^{I}E[y_{i,c}y^{\top}_{i,c}]` + + + :math:`A_2 = \sum\limits_{i=1}^{I} \Bigg[\sum\limits_{h=1}^{H}N_{i,h,c}(o_{i,h} - \mu_c -U_{c}x_{i,h,c} -D_{c}z_{i,c} )\Bigg]E[y_{i}]^{\top}` + + + More information, please, check the technical notes attached + + Parameters + ---------- + + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of ints + List of corresponding labels + + VProd: array + Matrix representing V_c.T*inv(Sigma_c) @ V_c.T + + n_acc: array + Accumulated 0th order statistics for each class (math:`N_{i}`) + + f_acc: array + Accumulated 1st order statistics for each class (math:`F_{i}`) + + latent_x: array + E(x) latent variable + + latent_y: array + E(y) latent variable + + latent_z: array + E(z) latent variable + + + Returns + ------- + + acc_V_A1: + (n_gaussians, r_V, r_V) A1 accumulator + + acc_V_A2: + (n_gaussians* feature_dimention, r_V) A2 accumulator + + """ + + ## U accumulators + acc_V_A1 = np.zeros((self.ubm.n_gaussians, self.r_V, self.r_V)) + acc_V_A2 = np.zeros((self.supervector_dimension, self.r_V)) + + # Loops over all people + for i in set(y): + n_acc_i = n_acc[i] + f_acc_i = f_acc[i] + X_i = self._get_statistics_by_class_id(X, y, i) + latent_x_i = latent_x[i] + latent_y_i = latent_y[i] + latent_z_i = latent_z[i] + + # Compyting A1 accumulator: \sum_{i=1}^{N}(E(y_i_c @ y_i_c.T)) + id_plus_prod_v_i = self._compute_id_plus_vprod_i(n_acc_i, VProd) + id_plus_prod_v_i += ( + latent_y_i[:, np.newaxis] @ latent_y_i[:, np.newaxis].T ) - self.U[ - c * self.feature_dimension : (c + 1) * self.feature_dimension, - :, - ] = U_c - pass + acc_V_A1 += mult_along_axis( + id_plus_prod_v_i[np.newaxis].repeat( + self.ubm.n_gaussians, axis=0 + ), + n_acc_i, + axis=0, + ) + + # Computing A2 accumulator: \sum_{i=1}^{N}( \sum_{h=1}^{H}(N_i_h_c (o_i_h, - m_c - D_c*z_i_c - U_c*x_i_h_c))@ E(y_i).T ) + fn_y_i = self._compute_fn_y_i( + X_i, + latent_x_i=latent_x_i, + latent_z_i=latent_z_i, + n_acc_i=n_acc_i, + f_acc_i=f_acc_i, + ) + + acc_V_A2 += fn_y_i[np.newaxis].T @ latent_y_i[:, np.newaxis].T + + return acc_V_A1, acc_V_A2 + + def _compute_fn_y_i(self, X_i, latent_x_i, latent_z_i, n_acc_i, f_acc_i): + """ + // Compute Fn_yi = sum_{sessions h}(N_{i,h}*(o_{i,h} - m - D*z_{i} - U*x_{i,h}) (Normalised first order statistics) + See equation (30) 
in [McCool2013]_ + + Parameters + ---------- + + X_i: list of :py:class:`bob.learn.em.GMMStats` + List of statistics for a class + + latent_x_i: array + E(x_i) latent variable + + latent_z_i: array + E(z_i) latent variable + + n_acc_i: array + Accumulated 0th order statistics for each class (math:`N_{i}`) + + f_acc_i: array + Accumulated 1st order statistics for each class (math:`F_{i}`) + + + """ + + U = self._U + D = self._D # Not doing the JFA + + m = self.mean_supervector + + # y = self.y[i] # Not doing JFA + + tmp_CD = np.repeat(n_acc_i, self.supervector_dimension // 2) + + fn_y_i = f_acc_i.flatten() - tmp_CD * ( + m - D * latent_z_i + ) # Fn_yi = sum_{sessions h}(N_{i,h}*(o_{i,h} - m - D*z_{i}) + + ### NOT DOING JFA + + # Looping over the sessions of a ;ane; + for session_id in range(len(X_i)): + n_i = X_i[session_id].n + U_dot_x = U @ latent_x_i[:, session_id] + tmp_CD = np.repeat(n_i, self.supervector_dimension // 2) + fn_y_i -= tmp_CD * U_dot_x + + return fn_y_i + + #################################################################################################################### + # Scoring + + def estimate_x(self, X): + + id_plus_us_prod_inv = self._compute_id_plus_us_prod_inv(X) + fn_x = self._compute_fn_x(X) + + # UtSigmaInv * Fn_x = Ut*diag(sigma)^-1 * N*(o - m) + ut_inv_sigma = self._U.T / self.variance_supervector + + return id_plus_us_prod_inv @ (ut_inv_sigma @ fn_x) + + def _compute_id_plus_us_prod_inv(self, X_i): + """ + Computes (Id + U^T.Sigma^-1.U.N_{i,h}.U)^-1 = + + Parameters + ---------- + + X_i: list of :py:class:`bob.learn.em.GMMStats` + List of statistics for a class + """ + I = np.eye(self.r_U, self.r_U) + + Uc = self._U.reshape( + (self.ubm.n_gaussians, self.feature_dimension, self.r_U) + ) + + UcT = np.transpose(Uc, axes=(0, 2, 1)) + + sigma_c = np.reshape( + self.variance_supervector, + (self.ubm.n_gaussians, self.feature_dimension), + ) + + n_i_c = np.expand_dims(X_i.n[:, np.newaxis], axis=2) + + id_plus_us_prod_inv = I + ( + ((UcT / sigma_c[:, np.newaxis]) @ Uc) * n_i_c + ).sum(axis=0) + id_plus_us_prod_inv = np.linalg.inv(id_plus_us_prod_inv) + + return id_plus_us_prod_inv + + def _compute_fn_x(self, X_i): + """ + Compute Fn_x = sum_{sessions h}(N*(o - m) (Normalised first order statistics) + + Parameters + ---------- + + X_i: list of :py:class:`bob.learn.em.GMMStats` + List of statistics for a class + + """ + + n = X_i.n[:, np.newaxis] + f = X_i.sum_px + + fn_x = f - self.ubm.means * n + + return fn_x.flatten() + + +class ISVMachine(FactorAnalysisBase): + """ + Implements the Interssion Varibility Modelling hypothesis on top of GMMs + + Inter-Session Variability (ISV) modeling is a session variability modeling technique built on top of the Gaussian mixture modeling approach. + It hypothesizes that within-class variations are embedded in a linear subspace in the GMM means subspace and these variations can be suppressed + by an offset w.r.t each mean during the MAP adaptation. 
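+
+    In supervector notation this amounts to modeling each observation as
+    :math:`\mu_{i, j} = m + Ux_{i, j} + Dz_{i}`, i.e. the JFA model below
+    without the between-class term :math:`Vy_{i}`.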
+ For more information check [McCool2013]_ + + Parameters + ---------- + + ubm: :py:class:`bob.learn.em.GMMMachine` + A trained UBM (Universal Background Model) + + r_U: int + Dimension of the subspace U + + em_iterations: int + Number of EM iterations + + relevance_factor: float + Factor analysis relevance factor + + seed: int + Seed for the random number generator + + """ + + def __init__(self, ubm, r_U, em_iterations, relevance_factor=4.0, seed=0): + super(ISVMachine, self).__init__( + ubm, + r_U=r_U, + relevance_factor=relevance_factor, + em_iterations=em_iterations, + seed=seed, + ) + + def initialize(self, X, y): + return super(ISVMachine, self).initialize(X, y) + + def e_step(self, X, y, n_acc, f_acc): + """ + E-step of the EM algorithm + """ + # self.initialize_XYZ(y) + UProd = self._compute_uprod() + latent_x, _, latent_z = self.initialize_XYZ(y) + latent_y = None + + latent_x = self.update_x(X, y, UProd, latent_x) + latent_z = self.update_z( + X, y, latent_x, latent_y, latent_z, n_acc, f_acc + ) + acc_U_A1, acc_U_A2 = self.compute_accumulators_U( + X, y, UProd, latent_x, latent_y, latent_z + ) + + return acc_U_A1, acc_U_A2 + + def m_step(self, acc_U_A1, acc_U_A2): + """ + ISV M-step. + This updates `U` matrix + + Parameters + ---------- + + acc_U_A1: array + Accumulated statistics for U_A1(n_gaussians, r_U, r_U) + + acc_U_A2: array + Accumulated statistics for U_A2(n_gaussians* feature_dimention, r_U) + + """ + + self.update_U(acc_U_A1, acc_U_A2) def fit(self, X, y): """ @@ -545,11 +1209,16 @@ class ISVMachine(FactorAnalysisBase): """ - self.create_UVD(y) - self.initialize(X, y) + # In case those variables are already set + if not hasattr(self, "_U") or not hasattr(self, "_D"): + self.create_UVD() + + # TODO: Point of parallelism + n_acc, f_acc = self.initialize(X, y) for i in range(self.em_iterations): logger.info("U Training: Iteration %d", i) - acc_U_A1, acc_U_A2 = self.e_step(X, y) + # TODO: Point of parallelism + acc_U_A1, acc_U_A2 = self.e_step(X, y, n_acc, f_acc) self.m_step(acc_U_A1, acc_U_A2) return self @@ -560,8 +1229,9 @@ class ISVMachine(FactorAnalysisBase): Parameters ---------- - X : numpy.ndarray - Nxd features of N GMM statistics + X : list of :py:class:`bob.learn.em.GMMStats` + List of statistics to be enrolled + iterations : int Number of iterations to perform @@ -576,65 +1246,174 @@ class ISVMachine(FactorAnalysisBase): n_acc = self._sum_n_statistics(X, y=y) f_acc = self._sum_f_statistics(X, y=y) - UProd = self.compute_uprod() + UProd = self._compute_uprod() latent_x, _, latent_z = self.initialize_XYZ(y) - + latent_y = None for i in range(iterations): logger.info("Enrollment: Iteration %d", i) - # latent_x = self.update_x(X, y, UProd, [np.zeros((2, 2))]) - latent_x = self.update_x(X, y, UProd, latent_x, latent_z) - latent_z = self.update_z(X, y, latent_x, latent_z, n_acc, f_acc) + latent_x = self.update_x(X, y, UProd, latent_x, latent_y, latent_z) + latent_z = self.update_z( + X, y, latent_x, latent_y, latent_z, n_acc, f_acc + ) return latent_z + def score(self, latent_z, data): + """ + Computes the ISV score + + Parameters + ---------- + latent_z : numpy.ndarray + Latent representation of the client (E[z_i]) + + data : list of :py:class:`bob.learn.em.GMMStats` + List of statistics to be scored + + Returns + ------- + score : float + The linear scored + + """ + x = self.estimate_x(data) + Ux = self._U @ x + + # TODO: I don't know why this is not the enrolled model + # Here I am just reproducing the C++ implementation + # m + Dz + z = self.D * latent_z + 
self.mean_supervector + + return linear_scoring( + z.reshape((self.ubm.n_gaussians, self.feature_dimension)), + self.ubm, + data, + Ux.reshape((self.ubm.n_gaussians, self.feature_dimension)), + frame_length_normalization=True, + )[0] + class JFAMachine(FactorAnalysisBase): """ - JFA + Joint Factor Analysis (JFA) is an extension of ISV. Besides the + within-class assumption (modeled with :math:`U`), it also hypothesize that + between class variations are embedded in a low rank rectangular matrix + :math:`V`. In the supervector notation, this modeling has the following shape: + :math:`\mu_{i, j} = m + Ux_{i, j} + Vy_{i} + D_z{i}`. + + For more information check [McCool2013]_ + + Parameters + ---------- + + ubm: :py:class:`bob.learn.em.GMMMachine` + A trained UBM (Universal Background Model) + + r_U: int + Dimension of the subspace U + + r_V: int + Dimension of the subspace V + + em_iterations: int + Number of EM iterations + + relevance_factor: float + Factor analysis relevance factor + + seed: int + Seed for the random number generator + """ - def __init__(self, ubm, r_U, em_iterations, relevance_factor=4.0, seed=0): - self.r_U = r_U - self.seed = seed - self.em_iterations = em_iterations - super(ISVMachine, self).__init__( - ubm, r_U=r_U, relevance_factor=relevance_factor + def __init__( + self, ubm, r_U, r_V, em_iterations, relevance_factor=4.0, seed=0 + ): + super(JFAMachine, self).__init__( + ubm, + r_U=r_U, + r_V=r_V, + relevance_factor=relevance_factor, + em_iterations=em_iterations, + seed=seed, ) def initialize(self, X, y): - return super(ISVMachine, self).initialize(X, y) + return super(JFAMachine, self).initialize(X, y) - def e_step(self, X, y, n_acc, f_acc): + def e_step_v(self, X, y, n_acc, f_acc): """ - E-step of the EM algorithm + ISV E-step for the V matrix. + + Parameters + ---------- + + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of int + List of labels + + n_acc: array + Accumulated 0th-order statistics + + f_acc: array + Accumulated 1st-order statistics + + + Returns + ---------- + + acc_V_A1: array + Accumulated statistics for V_A1(n_gaussians, r_V, r_V) + + acc_V_A2: array + Accumulated statistics for V_A2(n_gaussians* feature_dimension, r_V) + """ - # self.initialize_XYZ(y) - UProd = self.compute_uprod() + + VProd = self._compute_vprod() + latent_x, latent_y, latent_z = self.initialize_XYZ(y) - latent_x = self.update_x(X, y, UProd, latent_x) - latent_z = self.update_z(X, y, latent_x, latent_z, n_acc, f_acc) - acc_U_A1, acc_U_A2 = self.compute_accumulators_U( - X, y, UProd, latent_x, latent_z + #### UPDATE Y, X AND FINALY Z + + latent_y = self.update_y( + X, y, VProd, latent_x, latent_y, latent_z, n_acc, f_acc ) - return acc_U_A1, acc_U_A2 + acc_V_A1, acc_V_A2 = self.compute_accumulators_V( + X, y, VProd, n_acc, f_acc, latent_x, latent_y, latent_z + ) - def m_step(self, acc_U_A1, acc_U_A2): + return acc_V_A1, acc_V_A2 + + def m_step_v(self, acc_V_A1, acc_V_A2): """ - M-step of the EM algorithm + `V` Matrix M-step. 
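+        Per Gaussian :math:`c`, this applies the closed-form solution
+        :math:`V_c = A_{2,c} A_{1,c}^{-1}`, with the accumulators computed by
+        :py:meth:`compute_accumulators_V`.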
+ This updates the `V` matrix + + Parameters + ---------- + + acc_V_A1: array + Accumulated statistics for V_A1(n_gaussians, r_V, r_V) + + acc_V_A2: array + Accumulated statistics for V_A2(n_gaussians* feature_dimension, r_V) + """ # Inverting A1 over the zero axis # https://stackoverflow.com/questions/11972102/is-there-a-way-to-efficiently-invert-an-array-of-matrices-with-numpy - inv_A1 = np.linalg.inv(acc_U_A1) + inv_A1 = np.linalg.inv(acc_V_A1) - # Iterating over the gaussinas to update U + # Iterating over the gaussians to update V for c in range(self.ubm.n_gaussians): - U_c = ( - acc_U_A2[ + V_c = ( + acc_V_A2[ c * self.feature_dimension : (c + 1) * self.feature_dimension, @@ -642,39 +1421,214 @@ class JFAMachine(FactorAnalysisBase): ] @ inv_A1[c, :, :] ) - self.U[ + self._V[ c * self.feature_dimension : (c + 1) * self.feature_dimension, :, - ] = U_c + ] = V_c - pass + def finalize_v(self, X, y, n_acc, f_acc): + """ + Compute for the last time `E[y]` + + Parameters + ---------- + + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of int + List of labels + + n_acc: array + Accumulated 0th-order statistics + + f_acc: array + Accumulated 1st-order statistics + + Returns + ------- + latent_y: array + E[y] - def fit(self, X, y): """ - Trains the U matrix (session variability matrix) + VProd = self._compute_vprod() + + latent_x, latent_y, latent_z = self.initialize_XYZ(y) + + #### UPDATE Y, X AND FINALY Z + + latent_y = self.update_y( + X, y, VProd, latent_x, latent_y, latent_z, n_acc, f_acc + ) + return latent_y + + def e_step_u(self, X, y, latent_y): + """ + ISV E-step for the U matrix. Parameters ---------- - X : numpy.ndarray - Nxd features of N GMM statistics - y : numpy.ndarray - The input labels, a 1D numpy array of shape (number of samples, ) + + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of int + List of labels + + latent_y: array + E(y) latent variable + + + Returns + ---------- + + acc_U_A1: array + Accumulated statistics for U_A1(n_gaussians, r_U, r_U) + + acc_U_A2: array + Accumulated statistics for U_A2(n_gaussians* feature_dimention, r_U) + + """ + # self.initialize_XYZ(y) + UProd = self._compute_uprod() + latent_x, _, latent_z = self.initialize_XYZ(y) + + latent_x = self.update_x(X, y, UProd, latent_x, latent_y) + + acc_U_A1, acc_U_A2 = self.compute_accumulators_U( + X, y, UProd, latent_x, latent_y, latent_z + ) + + return acc_U_A1, acc_U_A2 + + def m_step_u(self, acc_U_A1, acc_U_A2): + """ + `U` Matrix M-step. + This updates the `U` matrix + + Parameters + ---------- + + acc_V_A1: array + Accumulated statistics for V_A1(n_gaussians, r_V, r_V) + + acc_V_A2: array + Accumulated statistics for V_A2(n_gaussians* feature_dimension, r_V) + + """ + + self.update_U(acc_U_A1, acc_U_A2) + + def finalize_u( + self, + X, + y, + latent_y, + ): + """ + Compute for the last time `E[x]` + + Parameters + ---------- + + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of int + List of labels + + latent_y: array + E[y] latent variable Returns ------- - self : object - Returns self. + latent_x: array + E[x] + """ + + UProd = self._compute_uprod() + latent_x, _, _ = self.initialize_XYZ(y) + + latent_x = self.update_x( + X, y, UProd, latent_x=latent_x, latent_y=latent_y + ) + + return latent_x + def e_step_d(self, X, y, latent_x, latent_y, n_acc, f_acc): """ + ISV E-step for the U matrix. 
- self.create_UVD(y) - self.initialize(X, y) - for i in range(self.em_iterations): - logger.info("U Training: Iteration %d", i) - acc_U_A1, acc_U_A2 = self.e_step(X, y) - self.m_step(acc_U_A1, acc_U_A2) + Parameters + ---------- - return self + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of int + List of labels + + latent_x: array + E(x) latent variable + + latent_y: array + E(y) latent variable + + latent_z: array + E(z) latent variable + + n_acc: array + Accumulated 0th-order statistics + + f_acc: array + Accumulated 1st-order statistics + + + Returns + ---------- + + acc_D_A1: array + Accumulated statistics for D_A1(n_gaussians* feature_dimension, ) + + acc_D_A2: array + Accumulated statistics for D_A2(n_gaussians* feature_dimension, ) + + """ + + _, _, latent_z = self.initialize_XYZ(y) + + latent_z = self.update_z( + X, + y, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, + n_acc=n_acc, + f_acc=f_acc, + ) + + acc_D_A1, acc_D_A2 = self.compute_accumulators_D( + X, y, latent_x, latent_y, latent_z, n_acc, f_acc + ) + + return acc_D_A1, acc_D_A2 + + def m_step_d(self, acc_D_A1, acc_D_A2): + """ + `D` Matrix M-step. + This updates the `D` matrix + + Parameters + ---------- + + acc_D_A1: array + Accumulated statistics for D_A1(n_gaussians* feature_dimension, ) + + acc_D_A2: array + Accumulated statistics for D_A2(n_gaussians* feature_dimension, ) + + """ + self._D = acc_D_A2 / acc_D_A1 def enroll(self, X, iterations=1): """ @@ -682,15 +1636,16 @@ class JFAMachine(FactorAnalysisBase): Parameters ---------- - X : numpy.ndarray - Nxd features of N GMM statistics + X : list of :py:class:`bob.learn.em.GMMStats` + List of statistics + iterations : int Number of iterations to perform Returns ------- self : object - z + z, y """ # We have only one class for enrollment @@ -698,13 +1653,112 @@ class JFAMachine(FactorAnalysisBase): n_acc = self._sum_n_statistics(X, y=y) f_acc = self._sum_f_statistics(X, y=y) - UProd = self.compute_uprod() - latent_x, _, latent_z = self.initialize_XYZ(y) + UProd = self._compute_uprod() + VProd = self._compute_vprod() + latent_x, latent_y, latent_z = self.initialize_XYZ(y) for i in range(iterations): logger.info("Enrollment: Iteration %d", i) - # latent_x = self.update_x(X, y, UProd, [np.zeros((2, 2))]) - latent_x = self.update_x(X, y, UProd, latent_x, latent_z) - latent_z = self.update_z(X, y, latent_x, latent_z, n_acc, f_acc) + latent_y = self.update_y( + X, y, VProd, latent_x, latent_y, latent_z, n_acc, f_acc + ) + latent_x = self.update_x(X, y, UProd, latent_x, latent_y, latent_z) + latent_z = self.update_z( + X, y, latent_x, latent_y, latent_z, n_acc, f_acc + ) - return latent_z + return latent_y, latent_z + + def fit(self, X, y): + """ + Trains the U matrix (session variability matrix) + + Parameters + ---------- + X : numpy.ndarray + Nxd features of N GMM statistics + y : numpy.ndarray + The input labels, a 1D numpy array of shape (number of samples, ) + + Returns + ------- + self : object + Returns self. 
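+
+        Note that training proceeds in three successive blocks of
+        `em_iterations` EM iterations each: the between-class subspace
+        :math:`V` is estimated first, then the within-class subspace
+        :math:`U`, and finally the diagonal offset :math:`D`.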
+ + """ + + # In case those variables are already set + if ( + not hasattr(self, "_U") + or not hasattr(self, "_V") + or not hasattr(self, "_D") + ): + self.create_UVD() + + # TODO: Point of parallelism + n_acc, f_acc = self.initialize(X, y) + + # Updating V + for i in range(self.em_iterations): + logger.info("V Training: Iteration %d", i) + # TODO: Point of parallelism + acc_V_A1, acc_V_A2 = self.e_step_v(X, y, n_acc, f_acc) + self.m_step_v(acc_V_A1, acc_V_A2) + latent_y = self.finalize_v(X, y, n_acc, f_acc) + + # Updating U + for i in range(self.em_iterations): + logger.info("U Training: Iteration %d", i) + # TODO: Point of parallelism + acc_U_A1, acc_U_A2 = self.e_step_u(X, y, latent_y) + self.m_step_u(acc_U_A1, acc_U_A2) + + latent_x = self.finalize_u(X, y, latent_y) + + # Updating D + for i in range(self.em_iterations): + logger.info("D Training: Iteration %d", i) + # TODO: Point of parallelism + acc_D_A1, acc_D_A2 = self.e_step_d( + X, y, latent_x, latent_y, n_acc, f_acc + ) + self.m_step_d(acc_D_A1, acc_D_A2) + + return self + + def score(self, model, data): + """ + Computes the ISV score + + Parameters + ---------- + latent_z : numpy.ndarray + Latent representation of the client (E[z_i]) + + data : list of :py:class:`bob.learn.em.GMMStats` + List of statistics to be scored + + Returns + ------- + score : float + The linear scored + + """ + latent_y = model[0] + latent_z = model[1] + + x = self.estimate_x(data) + Ux = self._U @ x + + # TODO: I don't know why this is not the enrolled model + # Here I am just reproducing the C++ implementation + # m + Vy + Dz + zy = self.V @ latent_y + self.D * latent_z + self.mean_supervector + + return linear_scoring( + zy.reshape((self.ubm.n_gaussians, self.feature_dimension)), + self.ubm, + data, + Ux.reshape((self.ubm.n_gaussians, self.feature_dimension)), + frame_length_normalization=True, + )[0] diff --git a/bob/learn/em/test/test_jfa.py b/bob/learn/em/test/test_jfa.py new file mode 100644 index 0000000..23086de --- /dev/null +++ b/bob/learn/em/test/test_jfa.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python +# vim: set fileencoding=utf-8 : +# Laurent El Shafey <Laurent.El-Shafey@idiap.ch> +# Tiago Freitas Pereira <tiago.pereira@idiap.ch> +# Tue Jul 19 12:16:17 2011 +0200 +# +# Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland + +import numpy as np +from bob.learn.em import GMMMachine, GMMStats, ISVMachine, JFAMachine +import copy + + +def test_JFAMachine(): + + eps = 1e-10 + + # Creates a UBM + weights = np.array([0.4, 0.6], "float64") + means = np.array([[1, 6, 2], [4, 3, 2]], "float64") + variances = np.array([[1, 2, 1], [2, 1, 2]], "float64") + ubm = GMMMachine(2, 3) + ubm.weights = weights + ubm.means = means + ubm.variances = variances + + # Defines GMMStats + gs = GMMStats(2, 3) + log_likelihood = -3.0 + T = 1 + n = np.array([0.4, 0.6], "float64") + sumpx = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], "float64") + sumpxx = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]], "float64") + gs.log_likelihood = log_likelihood + gs.t = T + gs.n = n + gs.sum_px = sumpx + gs.sum_pxx = sumpxx + + # Creates a JFAMachine + U = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], "float64") + V = np.array([[6, 5], [4, 3], [2, 1], [1, 2], [3, 4], [5, 6]], "float64") + d = np.array([0, 1, 0, 1, 0, 1], "float64") + m = JFAMachine(ubm, 2, 2, em_iterations=10) + m.U = U + m.V = V + m.D = d + + # Preparing the model + y = np.array([1, 2], "float64") + z = np.array([3, 4, 1, 2, 0, 1], "float64") + model = [y, z] + + score_ref = 
-2.111577181208289 + score = m.score(model, gs) + assert abs(score_ref - score) < eps + + +def test_ISVMachine(): + + eps = 1e-10 + + # Creates a UBM + weights = np.array([0.4, 0.6], "float64") + means = np.array([[1, 6, 2], [4, 3, 2]], "float64") + variances = np.array([[1, 2, 1], [2, 1, 2]], "float64") + ubm = GMMMachine(2, 3) + ubm.weights = weights + ubm.means = means + ubm.variances = variances + + # Creates a ISVMachine + U = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], "float64") + # V = numpy.array([[0], [0], [0], [0], [0], [0]], 'float64') + d = np.array([0, 1, 0, 1, 0, 1], "float64") + isv_machine = ISVMachine(ubm, r_U=2, em_iterations=10) + isv_machine.U = U + # base.v = V + isv_machine.D = d + + # Defines GMMStats + gs = GMMStats(2, 3) + log_likelihood = -3.0 + T = 1 + n = np.array([0.4, 0.6], "float64") + sumpx = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], "float64") + sumpxx = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]], "float64") + gs.log_likelihood = log_likelihood + gs.t = T + gs.n = n + gs.sum_px = sumpx + gs.sum_pxx = sumpxx + + # Enrolled model + latent_z = np.array([3, 4, 1, 2, 0, 1], "float64") + score = isv_machine.score(latent_z, gs) + score_ref = -3.280498193082100 + + assert abs(score_ref - score) < eps + pass diff --git a/bob/learn/em/test/test_jfa_trainer.py b/bob/learn/em/test/test_jfa_trainer.py index edec5fa..f9e51e4 100644 --- a/bob/learn/em/test/test_jfa_trainer.py +++ b/bob/learn/em/test/test_jfa_trainer.py @@ -7,8 +7,7 @@ # Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland import numpy as np -from bob.learn.em import GMMMachine, GMMStats, ISVMachine -import bob.core +from bob.learn.em import GMMMachine, GMMStats, ISVMachine, JFAMachine import copy # Define Training set and initial values for tests @@ -118,17 +117,17 @@ def test_JFATrainAndEnrol(): # Calls the train function ubm = GMMMachine(2, 3) - ubm.mean_supervector = UBM_MEAN - ubm.variance_supervector = UBM_VAR - mb = JFABase(ubm, 2, 2) - t = JFATrainer() - t.initialize(mb, TRAINING_STATS) - mb.u = M_u - mb.v = M_v - mb.d = M_d - bob.learn.em.train_jfa(t, mb, TRAINING_STATS, initialize=False) - - v_ref = numpy.array( + ubm.means = UBM_MEAN.reshape((2, 3)) + ubm.variances = UBM_VAR.reshape((2, 3)) + it = JFAMachine(ubm, 2, 2, em_iterations=10) + # n_acc, f_acc = it.initialize(TRAINING_STATS_X, TRAINING_STATS_y) + it.U = copy.deepcopy(M_u) + it.V = copy.deepcopy(M_v) + it.D = copy.deepcopy(M_d) + it.fit(TRAINING_STATS_X, TRAINING_STATS_y) + # bob.learn.em.train_jfa(t, mb, TRAINING_STATS, initialize=False) + + v_ref = np.array( [ [0.245364911936476, 0.978133261775424], [0.769646805052223, 0.940070736856596], @@ -139,7 +138,7 @@ def test_JFATrainAndEnrol(): ], "float64", ) - u_ref = numpy.array( + u_ref = np.array( [ [0.049424652628448, 0.060480486336896], [0.178104127464007, 1.884873813495153], @@ -150,7 +149,7 @@ def test_JFATrainAndEnrol(): ], "float64", ) - d_ref = numpy.array( + d_ref = np.array( [ 9.648467e-18, 2.63720683155e-12, @@ -163,15 +162,14 @@ def test_JFATrainAndEnrol(): ) eps = 1e-10 - assert numpy.allclose(mb.v, v_ref, eps) - assert numpy.allclose(mb.u, u_ref, eps) - assert numpy.allclose(mb.d, d_ref, eps) + assert np.allclose(it.V, v_ref, eps) + assert np.allclose(it.U, u_ref, eps) + assert np.allclose(it.D, d_ref, eps) # Calls the enroll function - m = JFAMachine(mb) - Ne = numpy.array([0.1579, 0.9245, 0.1323, 0.2458]).reshape((2, 2)) - Fe = numpy.array( + Ne = np.array([0.1579, 0.9245, 0.1323, 0.2458]).reshape((2, 2)) + Fe = np.array( [ 
0.1579, 0.1925, @@ -195,10 +193,10 @@ def test_JFATrainAndEnrol(): gse2.sum_px = Fe[:, 1].reshape(2, 3) gse = [gse1, gse2] - t.enroll(m, gse, 5) + latent_y, latent_z = it.enroll(gse, 5) - y_ref = numpy.array([0.555991469319657, 0.002773650670010], "float64") - z_ref = numpy.array( + y_ref = np.array([0.555991469319657, 0.002773650670010], "float64") + z_ref = np.array( [ 8.2228e-20, 3.15216909492e-13, @@ -209,10 +207,12 @@ def test_JFATrainAndEnrol(): ], "float64", ) - assert numpy.allclose(m.y, y_ref, eps) - assert numpy.allclose(m.z, z_ref, eps) + + assert np.allclose(latent_y, y_ref, eps) + assert np.allclose(latent_z, z_ref, eps) # Testing exceptions + """ nose.tools.assert_raises(RuntimeError, t.initialize, mb, [1, 2, 2]) nose.tools.assert_raises(RuntimeError, t.initialize, mb, [[1, 2, 2]]) nose.tools.assert_raises(RuntimeError, t.e_step_u, mb, [1, 2, 2]) @@ -231,43 +231,7 @@ def test_JFATrainAndEnrol(): nose.tools.assert_raises(RuntimeError, t.m_step_d, mb, [[1, 2, 2]]) nose.tools.assert_raises(RuntimeError, t.enroll, m, [[1, 2, 2]], 5) - - -def test_ISVTrainInitialize(): - - # Check that the initialization is consistent and using the rng (cf. issue #118) - eps = 1e-10 - - # UBM GMM - ubm = GMMMachine(2, 3) - ubm.means = UBM_MEAN.reshape((2, 3)) - ubm.variances = UBM_VAR.reshape((2, 3)) - - ## ISV - # ib = ISVBase(ubm, 2) - # first round - # rng = bob.core.random.mt19937(0) - it = ISVMachine(ubm, 2) - # it.rng = rng - - it.initialize( - TRAINING_STATS_X, TRAINING_STATS_y, seed=bob.core.random.mt19937(0) - ) - u1 = copy.deepcopy(it.U) - d1 = copy.deepcopy(it.D) - - # second round - rng = bob.core.random.mt19937(0) - it.rng = rng - it.initialize( - TRAINING_STATS_X, TRAINING_STATS_y, seed=bob.core.random.mt19937(0) - ) - u2 = it.U - d2 = it.D - - assert np.allclose(u1, u2, eps) - assert np.allclose(d1, d2, eps) - pass + """ def test_ISVTrainAndEnrol(): @@ -320,16 +284,10 @@ def test_ISVTrainAndEnrol(): r_U=2, relevance_factor=4.0, em_iterations=10, - seed=bob.core.random.mt19937(0), ) - n_acc, f_acc = it.initialize(TRAINING_STATS_X, TRAINING_STATS_y) - it.U = M_u - for i in range(10): - acc_U_A1, acc_U_A2 = it.e_step( - TRAINING_STATS_X, TRAINING_STATS_y, n_acc, f_acc - ) - it.m_step(acc_U_A1, acc_U_A2) + it.U = copy.deepcopy(M_u) + it = it.fit(TRAINING_STATS_X, TRAINING_STATS_y) assert np.allclose(it.D, d_ref, eps) assert np.allclose(it.U, u_ref, eps) @@ -373,3 +331,60 @@ def test_ISVTrainAndEnrol(): # nose.tools.assert_raises(RuntimeError, t.e_step, mb, [1, 2, 2]) # nose.tools.assert_raises(RuntimeError, t.e_step, mb, [[1, 2, 2]]) # nose.tools.assert_raises(RuntimeError, t.enroll, m, [[1, 2, 2]], 5) + + +def test_JFATrainInitialize(): + # Check that the initialization is consistent and using the rng (cf. issue #118) + + eps = 1e-10 + + # UBM GMM + ubm = GMMMachine(2, 3) + ubm.means = UBM_MEAN.reshape((2, 3)) + ubm.variances = UBM_VAR.reshape((2, 3)) + + ## JFA + it = JFAMachine(ubm, 2, 2, em_iterations=10) + # first round + + it.initialize(TRAINING_STATS_X, TRAINING_STATS_y) + u1 = it.U + v1 = it.V + d1 = it.D + + # second round + it.initialize(TRAINING_STATS_X, TRAINING_STATS_y) + u2 = it.U + v2 = it.V + d2 = it.D + + assert np.allclose(u1, u2, eps) + assert np.allclose(v1, v2, eps) + assert np.allclose(d1, d2, eps) + + +def test_ISVTrainInitialize(): + + # Check that the initialization is consistent and using the rng (cf. 
issue #118)
+    eps = 1e-10
+
+    # UBM GMM
+    ubm = GMMMachine(2, 3)
+    ubm.means = UBM_MEAN.reshape((2, 3))
+    ubm.variances = UBM_VAR.reshape((2, 3))
+
+    ## ISV
+    it = ISVMachine(ubm, 2, em_iterations=10)
+
+    # first round
+    it.initialize(TRAINING_STATS_X, TRAINING_STATS_y)
+    u1 = copy.deepcopy(it.U)
+    d1 = copy.deepcopy(it.D)
+
+    # second round
+    it.initialize(TRAINING_STATS_X, TRAINING_STATS_y)
+    u2 = it.U
+    d2 = it.D
+
+    assert np.allclose(u1, u2, eps)
+    assert np.allclose(d1, d2, eps)
-- 
GitLab
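
For reference, a minimal end-to-end usage sketch of the API introduced by this patch (training, enrollment and scoring with JFAMachine), reusing the toy UBM and statistics values from test_jfa.py above. This is illustrative only and not part of the patch; the `toy_stats` helper and the flattening of the enrolled latents before scoring are assumptions of the sketch, not part of the library API:

import numpy as np
from bob.learn.em import GMMMachine, GMMStats, JFAMachine

# Toy UBM with 2 Gaussians and 3 features, as in test_JFAMachine above.
ubm = GMMMachine(2, 3)
ubm.weights = np.array([0.4, 0.6])
ubm.means = np.array([[1.0, 6.0, 2.0], [4.0, 3.0, 2.0]])
ubm.variances = np.array([[1.0, 2.0, 1.0], [2.0, 1.0, 2.0]])


def toy_stats(n, sum_px):
    # Illustrative helper: in practice these zeroth/first order statistics
    # come from accumulating features against the UBM.
    gs = GMMStats(2, 3)
    gs.n = np.array(n)
    gs.sum_px = np.array(sum_px)
    gs.sum_pxx = np.array(sum_px) ** 2
    gs.t = 1
    gs.log_likelihood = -3.0
    return gs


# Two classes with one session each.
X = [
    toy_stats([0.4, 0.6], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
    toy_stats([0.6, 0.4], [[2.0, 1.0, 3.0], [3.0, 5.0, 4.0]]),
]
y = [0, 1]

# Train V, U and D, then enroll a client model and score a probe.
jfa = JFAMachine(ubm, r_U=2, r_V=2, em_iterations=10).fit(X, y)
latent_y, latent_z = jfa.enroll([X[0]], iterations=5)
# enroll() returns (1, r_V) / (1, supervector_dimension) arrays for the single
# enrolled class; score() expects the flat vectors used in test_JFAMachine.
score = jfa.score([latent_y.flatten(), latent_z.flatten()], X[1])
print(score)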