diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1daa2300fc671ed82d211795f4abe5f019b8d7fe..833ff2341e4272ee8687a877a1a39ee360be8fb9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,6 +14,11 @@ repos: rev: 3.9.2 hooks: - id: flake8 + exclude: | + (?x)^( + bob/devtools/templates/setup.py| + deps/bob-devel/run_test.py + )$ - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.1.0 hooks: diff --git a/bob/learn/em/__init__.py b/bob/learn/em/__init__.py index 6c50a0a6c1a429550f2e377665a21b5d13431ce3..fda23d6cb6bd7992e628f634526fae1889945d66 100644 --- a/bob/learn/em/__init__.py +++ b/bob/learn/em/__init__.py @@ -1,5 +1,6 @@ import bob.extension +from .factor_analysis import ISVMachine, JFAMachine from .gmm import GMMMachine, GMMStats from .kmeans import KMeansMachine from .linear_scoring import linear_scoring # noqa: F401 @@ -29,10 +30,6 @@ def __appropriate__(*args): __appropriate__( - KMeansMachine, - GMMMachine, - GMMStats, - WCCN, - Whitening, + KMeansMachine, GMMMachine, GMMStats, WCCN, Whitening, ISVMachine, JFAMachine ) __all__ = [_ for _ in dir() if not _.startswith("_")] diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..9b9a8894d56e530e661149e7d7f7e81760cb3926 --- /dev/null +++ b/bob/learn/em/factor_analysis.py @@ -0,0 +1,2229 @@ +#!/usr/bin/env python +# @author: Tiago de Freitas Pereira + + +import functools +import logging +import operator + +import dask +import numpy as np + +from dask.array.core import Array +from dask.delayed import Delayed +from sklearn.base import BaseEstimator +from sklearn.utils import check_consistent_length +from sklearn.utils.multiclass import unique_labels + +from .gmm import GMMMachine +from .linear_scoring import linear_scoring +from .utils import check_and_persist_dask_input + +logger = logging.getLogger(__name__) + + +def is_input_dask_nested(X): + """ + Check if the input is a dask delayed or array or a (nested) list of dask + delayed or array objects. + """ + if isinstance(X, (list, tuple)): + return is_input_dask_nested(X[0]) + + if isinstance(X, (Delayed, Array)): + return True + else: + return False + + +def check_dask_input_samples_per_class(X, y): + input_is_dask = is_input_dask_nested(X) + + if input_is_dask: + n_classes = len(y) + n_samples_per_class = [len(yy) for yy in y] + else: + unique_labels_y = unique_labels(y) + n_classes = len(unique_labels_y) + n_samples_per_class = [ + sum(y == class_id) for class_id in unique_labels_y + ] + return input_is_dask, n_classes, n_samples_per_class + + +def reduce_iadd(a): + """Reduces a list by adding all elements into the first element""" + return functools.reduce(operator.iadd, a) + + +def mult_along_axis(A, B, axis): + """ + Magic function to multiply two arrays along a given axis. + Taken from https://stackoverflow.com/questions/30031828/multiply-numpy-ndarray-with-1d-array-along-a-given-axis + """ + + # ensure we're working with Numpy arrays + A = np.array(A) + B = np.array(B) + + # shape check + if axis >= A.ndim: + raise np.AxisError(axis, A.ndim) + if A.shape[axis] != B.size: + raise ValueError( + "Length of 'A' along the given axis must be the same as B.size" + ) + + # np.broadcast_to puts the new axis as the last axis, so + # we swap the given axis with the last one, to determine the + # corresponding array shape. np.swapaxes only returns a view + # of the supplied array, so no data is copied unnecessarily. 
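+    # For example (hypothetical shapes): with A of shape (3, 4, 5), axis=1 and
+    # B of size 4, the shape computed below is (3, 5, 4); B is broadcast to it
+    # and then swapped back to (3, 4, 5), so that B varies along axis 1 only.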
+ shape = np.swapaxes(A, A.ndim - 1, axis).shape + + # Broadcast to an array with the shape as above. Again, + # no data is copied, we only get a new look at the existing data. + B_brc = np.broadcast_to(B, shape) + + # Swap back the axes. As before, this only changes our "point of view". + B_brc = np.swapaxes(B_brc, A.ndim - 1, axis) + + return A * B_brc + + +class FactorAnalysisBase(BaseEstimator): + """ + Factor Analysis base class. + This class is not intended to be used directly, but rather to be inherited from. + For more information check [McCool2013]_ . + + + Parameters + ---------- + + r_U: int + Dimension of the subspace U + + r_V: int + Dimension of the subspace V + + em_iterations: int + Number of EM iterations + + relevance_factor: float + Factor analysis relevance factor + + random_state: int + random_state for the random number generator + + ubm: :py:class:`bob.learn.em.GMMMachine` + A trained UBM (Universal Background Model) or a parametrized + :py:class:`bob.learn.em.GMMMachine` to train the UBM with. If None, + `ubm_kwargs` are passed as parameters of a new + :py:class:`bob.learn.em.GMMMachine`. + """ + + def __init__( + self, + r_U, + r_V=None, + relevance_factor=4.0, + em_iterations=10, + random_state=0, + ubm=None, + ubm_kwargs=None, + **kwargs, + ): + super().__init__(**kwargs) + self.ubm = ubm + self.ubm_kwargs = ubm_kwargs + self.em_iterations = em_iterations + self.random_state = random_state + + # axis 1 dimensions of U and V + self.r_U = r_U + self.r_V = r_V + + self.relevance_factor = relevance_factor + + if ubm is not None and ubm._means is not None: + self.create_UVD() + + @property + def feature_dimension(self): + """Get the UBM Dimension""" + + # TODO: Add this on the GMMMachine class + return self.ubm.means.shape[1] + + @property + def supervector_dimension(self): + """Get the supervector dimension""" + return self.ubm.n_gaussians * self.feature_dimension + + @property + def mean_supervector(self): + """ + Returns the mean supervector + """ + return self.ubm.means.flatten() + + @property + def variance_supervector(self): + """ + Returns the variance supervector + """ + return self.ubm.variances.flatten() + + @property + def U(self): + """An alias for `_U`.""" + return self._U + + @U.setter + def U(self, value): + self._U = np.array(value) + + @property + def D(self): + """An alias for `_D`.""" + return self._D + + @D.setter + def D(self, value): + self._D = np.array(value) + + @property + def V(self): + """An alias for `_V`.""" + return self._V + + @V.setter + def V(self, value): + self._V = np.array(value) + + def estimate_number_of_classes(self, y): + """ + Estimates the number of classes given the labels + """ + + return len(unique_labels(y)) + + def initialize(self, X): + """ + Accumulating 0th and 1st order statistics. Trains the UBM if needed. 
+ + Parameters + ---------- + X: list of numpy arrays + List of data to accumulate the statistics + y: list of ints + + Returns + ------- + + n_acc: array + (n_classes, n_gaussians) representing the accumulated 0th order statistics + + f_acc: array + (n_classes, n_gaussians, feature_dim) representing the accumulated 1st order statistics + + """ + + if self.ubm is None: + logger.info("Creating a new GMMMachine and training it.") + self.ubm = GMMMachine(**self.ubm_kwargs) + self.ubm.fit(X) + + if self.ubm._means is None: + logger.info("UBM means are None, training the UBM.") + self.ubm.fit(X) + + # Initializing the state matrix + if not hasattr(self, "_U") or not hasattr(self, "_D"): + self.create_UVD() + + def initialize_using_stats(self, ubm_projected_X, y, n_classes): + # Accumulating 0th and 1st order statistics + # https://gitlab.idiap.ch/bob/bob.learn.em/-/blob/da92d0e5799d018f311f1bf5cdd5a80e19e142ca/bob/learn/em/cpp/ISVTrainer.cpp#L68 + + if is_input_dask_nested(ubm_projected_X): + n_acc = [ + dask.delayed(self._sum_n_statistics)(xx, yy, n_classes) + for xx, yy in zip(ubm_projected_X, y) + ] + + f_acc = [ + dask.delayed(self._sum_f_statistics)(xx, yy, n_classes) + for xx, yy in zip(ubm_projected_X, y) + ] + n_acc, f_acc = dask.compute(n_acc, f_acc) + n_acc = reduce_iadd(n_acc) + f_acc = reduce_iadd(f_acc) + else: + # 0th order stats + n_acc = self._sum_n_statistics(ubm_projected_X, y, n_classes) + # 1st order stats + f_acc = self._sum_f_statistics(ubm_projected_X, y, n_classes) + + return n_acc, f_acc + + def create_UVD(self): + """ + Create the state matrices U, V and D + + Returns + ------- + + U: (n_gaussians*feature_dimension, r_U) represents the session variability matrix (within-class variability) + + V: (n_gaussians*feature_dimension, r_V) represents the session variability matrix (between-class variability) + + D: (n_gaussians*feature_dimension) represents the client offset vector + + """ + if self.random_state is not None: + np.random.seed(self.random_state) + + U_shape = (self.supervector_dimension, self.r_U) + + # U matrix is initialized using a normal distribution + self._U = np.random.normal(scale=1.0, loc=0.0, size=U_shape) + + # D matrix is initialized as `D = sqrt(variance(UBM) / relevance_factor)` + self._D = np.sqrt(self.variance_supervector / self.relevance_factor) + + # V matrix (or between-class variation matrix) + # TODO: so far not doing JFA + if self.r_V is not None: + V_shape = (self.supervector_dimension, self.r_V) + self._V = np.random.normal(scale=1.0, loc=0.0, size=V_shape) + else: + self._V = 0 + + def _sum_n_statistics(self, X, y, n_classes): + """ + Accumulates the 0th statistics for each client + + Parameters + ---------- + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics of each sample + + y: list of ints + List of corresponding labels + + n_classes: int + Number of classes + + Returns + ------- + n_acc: array + (n_classes, n_gaussians) representing the accumulated 0th order statistics + + """ + # 0th order stats + n_acc = np.zeros((n_classes, self.ubm.n_gaussians)) + + # Iterate for each client + for x_i, y_i in zip(X, y): + # Accumulate the 0th statistics for each class + n_acc[y_i, :] += x_i.n + + return n_acc + + def _sum_f_statistics(self, X, y, n_classes): + """ + Accumulates the 1st order statistics for each client + + Parameters + ---------- + X: list of :py:class:`bob.learn.em.GMMStats` + + y: list of ints + List of corresponding labels + + n_classes: int + Number of classes + + Returns + ------- + f_acc: array + (n_classes, 
n_gaussians, feature_dimension) representing the accumulated 1st order statistics + + """ + + # 1st order stats + f_acc = np.zeros( + ( + n_classes, + self.ubm.n_gaussians, + self.feature_dimension, + ) + ) + # Iterate for each client + for x_i, y_i in zip(X, y): + # Accumulate the 1st order statistics + f_acc[y_i, :, :] += x_i.sum_px + + return f_acc + + def _get_statistics_by_class_id(self, X, y, i): + """ + Returns the statistics for a given class + + Parameters + ---------- + X: list of :py:class:`bob.learn.em.GMMStats` + + y: list of ints + List of corresponding labels + + i: int + Class id to return the statistics for + """ + X = np.array(X) + return list(X[np.array(y) == i]) + + """ + Estimating U and x + """ + + def _compute_id_plus_u_prod_ih(self, x_i, UProd): + """ + Computes ( I+Ut*diag(sigma)^-1*Ni*U)^-1 + See equation (29) in [McCool2013]_ + + Parameters + ---------- + x_i: :py:class:`bob.learn.em.GMMStats` + Statistics of a single sample + + UProd: array + Matrix containing U_c.T*inv(Sigma_c) @ U_c.T + + Returns + ------- + id_plus_u_prod_ih: array + ( I+Ut*diag(sigma)^-1*Ni*U)^-1 + + """ + + n_i = x_i.n + I = np.eye(self.r_U, self.r_U) # noqa: E741 + + # TODO: make the invertion matrix function as a parameter + return np.linalg.inv(I + (UProd * n_i[:, None, None]).sum(axis=0)) + + def _compute_fn_x_ih(self, x_i, latent_z_i=None, latent_y_i=None): + """ + Computes Fn_x_ih = N_{i,h}*(o_{i,h} - m - D*z_{i} - V*y_{i}) + Check equation (29) in [McCool2013]_ + + Parameters + ---------- + x_i: :py:class:`bob.learn.em.GMMStats` + Statistics of a single sample + + latent_z_i: array + E[z_i] for class `i` + + latent_y_i: array + E[y_i] for class `i` + + """ + f_i = x_i.sum_px + n_i = x_i.n + n_ic = np.repeat(n_i, self.feature_dimension) + V = self._V + + # N_ih*( m + D*z) + # z is zero when the computation flow comes from update_X + if latent_z_i is None: + # Fn_x_ih = N_{i,h}*(o_{i,h} - m) + fn_x_ih = f_i.flatten() - n_ic * (self.mean_supervector) + else: + # code goes here when the computation flow comes from compute_accumulators + # Fn_x_ih = N_{i,h}*(o_{i,h} - m - D*z_{i}) + fn_x_ih = f_i.flatten() - n_ic * ( + self.mean_supervector + self._D * latent_z_i + ) + + """ + # JFA Part (eq 29) + """ + V_dot_v = V @ latent_y_i if latent_y_i is not None else 0 + fn_x_ih -= n_ic * V_dot_v if latent_y_i is not None else 0 + + return fn_x_ih + + def compute_latent_x( + self, *, X, y, n_classes, UProd, latent_y=None, latent_z=None + ): + """ + Computes a new math:`E[x]` See equation (29) in [McCool2013]_ + + + Parameters + ---------- + + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of ints + List of corresponding labels + + UProd: array + Matrix containing U_c.T*inv(Sigma_c) @ U_c.T + + latent_y: array + E(y) latent variable + + latent_z: array + E(z) latent variable + + Returns + ------- + Returns the new latent_x + """ + + # U.T @ inv(Sigma) - See Eq(37) + UTinvSigma = self._U.T / self.variance_supervector + + if is_input_dask_nested(X): + latent_x = [ + dask.delayed(self._compute_latent_x_per_class)( + X_i=X_i, + UProd=UProd, + UTinvSigma=UTinvSigma, + latent_y_i=latent_y[y_i] if latent_y is not None else None, + latent_z_i=latent_z[y_i] if latent_z is not None else None, + ) + for y_i, X_i in enumerate(X) + ] + latent_x = dask.compute(*latent_x) + else: + latent_x = [None] * n_classes + for y_i in unique_labels(y): + latent_x[y_i] = self._compute_latent_x_per_class( + X_i=np.array(X)[y == y_i], + UProd=UProd, + UTinvSigma=UTinvSigma, + 
latent_y_i=latent_y[y_i] if latent_y is not None else None, + latent_z_i=latent_z[y_i] if latent_z is not None else None, + ) + + return latent_x + + def _compute_latent_x_per_class( + self, *, X_i, UProd, UTinvSigma, latent_y_i, latent_z_i + ): + # For each sample + latent_x_i = [] + for x_i in X_i: + id_plus_prod_ih = self._compute_id_plus_u_prod_ih(x_i, UProd) + + fn_x_ih = self._compute_fn_x_ih( + x_i, latent_z_i=latent_z_i, latent_y_i=latent_y_i + ) + latent_x_i.append(id_plus_prod_ih @ (UTinvSigma @ fn_x_ih)) + latent_x_i = np.swapaxes(latent_x_i, 0, 1) + return latent_x_i + + def update_U(self, acc_U_A1, acc_U_A2): + """ + Update rule for U + + Parameters + ---------- + + acc_U_A1: array + Accumulated statistics for U_A1(n_gaussians, r_U, r_U) + + acc_U_A2: array + Accumulated statistics for U_A2(n_gaussians* feature_dimension, r_U) + + """ + + # Inverting A1 over the zero axis + # https://stackoverflow.com/questions/11972102/is-there-a-way-to-efficiently-invert-an-array-of-matrices-with-numpy + inv_A1 = np.linalg.inv(acc_U_A1) + + # Iterating over the gaussians to update U + U_c = ( + acc_U_A2.reshape( + self.ubm.n_gaussians, self.feature_dimension, self.r_U + ) + @ inv_A1 + ) + self._U = U_c.reshape( + self.ubm.n_gaussians * self.feature_dimension, self.r_U + ) + return self._U + + def _compute_uprod(self): + """ + Computes U_c.T*inv(Sigma_c) @ U_c.T + + + https://gitlab.idiap.ch/bob/bob.learn.em/-/blob/da92d0e5799d018f311f1bf5cdd5a80e19e142ca/bob/learn/em/cpp/FABaseTrainer.cpp#L325 + """ + # UProd = (self.ubm.n_gaussians, self.r_U, self.r_U) + + Uc = self._U.reshape( + (self.ubm.n_gaussians, self.feature_dimension, self.r_U) + ) + UcT = Uc.transpose(0, 2, 1) + + sigma_c = self.ubm.variances[:, np.newaxis] + UProd = (UcT / sigma_c) @ Uc + + return UProd + + def compute_accumulators_U(self, X, y, UProd, latent_x, latent_y, latent_z): + """ + Computes the accumulators (A1 and A2) for the U matrix. + This is useful for parallelization purposes. 
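+        Because both accumulators are plain sums over samples, each dask
+        worker can compute partial :math:`A_1` and :math:`A_2` on its chunk of
+        classes, and the M-step only has to add the partial results together
+        (see :py:meth:`ISVMachine.m_step`).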
+ + The accumulators are defined as + + :math:`A_1 = \\sum\\limits_{i=1}^{I}\\sum\\limits_{h=1}^{H}N_{i,h,c}E(x_{i,h,c} x^{\\top}_{i,h,c})` + + + :math:`A_2 = \\sum\\limits_{i=1}^{I}\\sum\\limits_{h=1}^{H}N_{i,h,c}(o_{i,h} - \\mu_c -D_{c}z_{i,c} -V_{c}y_{i,c} )E[x_{i,h}]^{\\top}` + + + More information, please, check the technical notes attached + + Parameters + ---------- + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of ints + List of corresponding labels + + UProd: array + Matrix containing U_c.T*inv(Sigma_c) @ U_c.T + + latent_x: array + E(x) latent variable + + latent_y: array + E(y) latent variable + + latent_z: array + E(z) latent variable + + Returns + ------- + acc_U_A1: + (n_gaussians, r_U, r_U) A1 accumulator + + acc_U_A2: + (n_gaussians* feature_dimension, r_U) A2 accumulator + + + """ + + # U accumulators + acc_U_A1 = np.zeros((self.ubm.n_gaussians, self.r_U, self.r_U)) + acc_U_A2 = np.zeros((self.supervector_dimension, self.r_U)) + + # Loops over all people + for y_i in set(y): + # For each session + for session_index, x_i in enumerate( + self._get_statistics_by_class_id(X, y, y_i) + ): + id_plus_prod_ih = self._compute_id_plus_u_prod_ih(x_i, UProd) + latent_z_i = latent_z[y_i] if latent_z is not None else None + latent_y_i = latent_y[y_i] if latent_y is not None else None + fn_x_ih = self._compute_fn_x_ih( + x_i, latent_y_i=latent_y_i, latent_z_i=latent_z_i + ) + + latent_x_i = latent_x[y_i][:, session_index] + id_plus_prod_ih += ( + latent_x_i[:, np.newaxis] @ latent_x_i[:, np.newaxis].T + ) + + acc_U_A1 += mult_along_axis( + id_plus_prod_ih[np.newaxis].repeat( + self.ubm.n_gaussians, axis=0 + ), + x_i.n, + axis=0, + ) + + acc_U_A2 += fn_x_ih[np.newaxis].T @ latent_x_i[:, np.newaxis].T + + return acc_U_A1, acc_U_A2 + + """ + Estimating D and z + """ + + def update_z(self, X, y, latent_x, latent_y, latent_z, n_acc, f_acc): + """ + Computes a new math:`E[z]` See equation (30) in [McCool2013]_ + + Parameters + ---------- + + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of ints + List of corresponding labels + + latent_x: array + E(x) latent variable + + latent_y: array + E(y) latent variable + + latent_z: array + E(z) latent variable + + n_acc: array + Accumulated 0th order statistics for each class (math:`N_{i}`) + + f_acc: array + Accumulated 1st order statistics for each class (math:`F_{i}`) + + Returns + ------- + Returns the new latent_z + + """ + + # Precomputing + # self._D.T / sigma + dt_inv_sigma = self._D / self.variance_supervector + # self._D.T / sigma * self._D + dt_inv_sigma_d = dt_inv_sigma * self._D + + # for each class + for y_i in set(y): + + id_plus_d_prod = self._compute_id_plus_d_prod_i( + dt_inv_sigma_d, n_acc[y_i] + ) + X_i = self._get_statistics_by_class_id(X, y, y_i) + latent_x_i = latent_x[y_i] + + latent_y_i = latent_y[y_i] if latent_y is not None else None + + fn_z_i = self._compute_fn_z_i( + X_i, latent_x_i, latent_y_i, n_acc[y_i], f_acc[y_i] + ) + latent_z[y_i] = id_plus_d_prod * dt_inv_sigma * fn_z_i + + return latent_z + + def _compute_id_plus_d_prod_i(self, dt_inv_sigma_d, n_acc_i): + """ + Computes: (I+Dt*diag(sigma)^-1*Ni*D)^-1 + See equation (31) in [McCool2013]_ + + Parameters + ---------- + + i: int + Class id + + dt_inv_sigma_d: array + Matrix representing `D.T / sigma` + + """ + + tmp_CD = np.repeat(n_acc_i, self.feature_dimension) + id_plus_d_prod = np.ones(tmp_CD.shape) + dt_inv_sigma_d * tmp_CD + return 1 / id_plus_d_prod + + def _compute_fn_z_i(self, X_i, latent_x_i, 
latent_y_i, n_acc_i, f_acc_i): + """ + Compute Fn_z_i = sum_{sessions h}(N_{i,h}*(o_{i,h} - m - V*y_{i} - U*x_{i,h}) (Normalized first order statistics) + + Parameters + ---------- + i: int + Class id + + """ + + U = self._U + V = self._V + + m = self.mean_supervector + + tmp_CD = np.repeat(n_acc_i, self.feature_dimension) + + # JFA session part + V_dot_v = V @ latent_y_i if latent_y_i is not None else 0 + + # m_cache_Fn_z_i = Fi - m_tmp_CD * (m + m_tmp_CD_b); // Fn_yi = sum_{sessions h}(N_{i,h}*(o_{i,h} - m - V*y_{i}) + fn_z_i = f_acc_i.flatten() - tmp_CD * (m + V_dot_v) + + # Looping over the sessions + for session_id in range(len(X_i)): + n_i = X_i[session_id].n + tmp_CD = np.repeat(n_i, self.feature_dimension) + x_i_h = latent_x_i[:, session_id] + + fn_z_i -= tmp_CD * (U @ x_i_h) + + return fn_z_i + + def compute_accumulators_D( + self, X, y, latent_x, latent_y, latent_z, n_acc, f_acc + ): + """ + Compute the accumulators for the D matrix + + The accumulators are defined as + + :math:`A_1 = \\sum\\limits_{i=1}^{I}E[z_{i,c}z^{\\top}_{i,c}]` + + + :math:`A_2 = \\sum\\limits_{i=1}^{I} \\Bigg[\\sum\\limits_{h=1}^{H}N_{i,h,c}(o_{i,h} - \\mu_c -U_{c}x_{i,h,c} -V_{c}y_{i,c} )\\Bigg]E[z_{i}]^{\\top}` + + + More information, please, check the technical notes attached + + + Parameters + ---------- + + X: array + Input data + + y: array + Class labels + + latent_z: array + E(z) latent variable + + latent_x: array + E(x) latent variable + + latent_y: array + E(y) latent variable + + n_acc: array + Accumulated 0th order statistics for each class (math:`N_{i}`) + + f_acc: array + Accumulated 1st order statistics for each class (math:`F_{i}`) + + Returns + ------- + acc_D_A1: + (n_gaussians* feature_dimension) A1 accumulator + + acc_D_A2: + (n_gaussians* feature_dimension) A2 accumulator + + """ + + acc_D_A1 = np.zeros((self.supervector_dimension,)) + acc_D_A2 = np.zeros((self.supervector_dimension,)) + + # Precomputing + # self._D.T / sigma + dt_inv_sigma = self._D / self.variance_supervector + # self._D.T / sigma * self._D + dt_inv_sigma_d = dt_inv_sigma * self._D + + # Loops over all people + for y_i in set(y): + + id_plus_d_prod = self._compute_id_plus_d_prod_i( + dt_inv_sigma_d, n_acc[y_i] + ) + X_i = self._get_statistics_by_class_id(X, y, y_i) + latent_x_i = latent_x[y_i] + + latent_y_i = latent_y[y_i] if latent_y is not None else None + + fn_z_i = self._compute_fn_z_i( + X_i, latent_x_i, latent_y_i, n_acc[y_i], f_acc[y_i] + ) + + tmp_CD = np.repeat(n_acc[y_i], self.feature_dimension) + acc_D_A1 += ( + id_plus_d_prod + latent_z[y_i] * latent_z[y_i] + ) * tmp_CD + acc_D_A2 += fn_z_i * latent_z[y_i] + + return acc_D_A1, acc_D_A2 + + def initialize_XYZ(self, n_samples_per_class): + """ + Initialize E[x], E[y], E[z] state variables + + Eq. (38) + latent_z = (n_classes, supervector_dimension) + + + Eq. (37) + latent_y = (n_classes, r_V) or None + + Eq. (36) + latent_x = (n_classes, r_U, n_sessions) + + """ + + # x (Eq. 
36) + # (n_classes, r_U, n_samples ) + latent_x = [] + for n_s in n_samples_per_class: + latent_x.append(np.zeros((self.r_U, n_s))) + + n_classes = len(n_samples_per_class) + latent_y = ( + np.zeros((n_classes, self.r_V)) + if self.r_V and self.r_V > 0 + else None + ) + + latent_z = np.zeros((n_classes, self.supervector_dimension)) + + return latent_x, latent_y, latent_z + + """ + Estimating V and y + """ + + def update_y( + self, + *, + X, + y, + n_classes, + VProd, + latent_x, + latent_y, + latent_z, + n_acc, + f_acc, + ): + """ + Computes a new math:`E[y]` See equation (30) in [McCool2013]_ + + Parameters + ---------- + + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of ints + List of corresponding labels + + n_classes: int + Number of classes + + VProd: array + Matrix representing V_c.T*inv(Sigma_c) @ V_c.T + + latent_x: array + E(x) latent variable + + latent_y: array + E(y) latent variable + + latent_z: array + E(z) latent variable + + n_acc: array + Accumulated 0th order statistics for each class (math:`N_{i}`) + + f_acc: array + Accumulated 1st order statistics for each class (math:`F_{i}`) + + """ + # V.T / sigma + VTinvSigma = self._V.T / self.variance_supervector + + if is_input_dask_nested(X): + latent_y = [ + dask.delayed(self._latent_y_per_class)( + X_i=X_i, + n_acc_i=n_acc[label], + f_acc_i=f_acc[label], + VProd=VProd, + VTinvSigma=VTinvSigma, + latent_x_i=latent_x[label], + latent_z_i=latent_z[label], + ) + for label, X_i in enumerate(X) + ] + latent_y = dask.compute(*latent_y) + else: + # Loops over the labels + for label in range(n_classes): + X_i = self._get_statistics_by_class_id(X, y, label) + latent_y[label] = self._latent_y_per_class( + X_i=X_i, + n_acc_i=n_acc[label], + f_acc_i=f_acc[label], + VProd=VProd, + VTinvSigma=VTinvSigma, + latent_x_i=latent_x[label], + latent_z_i=latent_z[label], + ) + return latent_y + + def _latent_y_per_class( + self, + *, + X_i, + n_acc_i, + f_acc_i, + VProd, + VTinvSigma, + latent_x_i, + latent_z_i, + ): + id_plus_v_prod_i = self._compute_id_plus_vprod_i(n_acc_i, VProd) + fn_y_i = self._compute_fn_y_i( + X_i, + latent_x_i, + latent_z_i, + n_acc_i, + f_acc_i, + ) + latent_y = (VTinvSigma @ fn_y_i) @ id_plus_v_prod_i + return latent_y + + def _compute_id_plus_vprod_i(self, n_acc_i, VProd): + """ + Computes: (I+Vt*diag(sigma)^-1*Ni*V)^-1 (see Eq. 
(30) in [McCool2013]_) + + Parameters + ---------- + + n_acc_i: array + Accumulated 0th order statistics for each class (math:`N_{i}`) + + VProd: array + Matrix representing V_c.T*inv(Sigma_c) @ V_c.T + + """ + + I = np.eye(self.r_V, self.r_V) # noqa: E741 + + # TODO: make the invertion matrix function as a parameter + return np.linalg.inv(I + (VProd * n_acc_i[:, None, None]).sum(axis=0)) + + def _compute_vprod(self): + """ + Computes V_c.T*inv(Sigma_c) @ V_c.T + + + https://gitlab.idiap.ch/bob/bob.learn.em/-/blob/da92d0e5799d018f311f1bf5cdd5a80e19e142ca/bob/learn/em/cpp/FABaseTrainer.cpp#L193 + """ + + Vc = self._V.reshape( + (self.ubm.n_gaussians, self.feature_dimension, self.r_V) + ) + VcT = Vc.transpose(0, 2, 1) + + sigma_c = self.ubm.variances[:, np.newaxis] + VProd = (VcT / sigma_c) @ Vc + + return VProd + + def compute_accumulators_V( + self, X, y, VProd, n_acc, f_acc, latent_x, latent_y, latent_z + ): + """ + Computes the accumulators for the update of V matrix + The accumulators are defined as + + :math:`A_1 = \\sum\\limits_{i=1}^{I}E[y_{i,c}y^{\\top}_{i,c}]` + + + :math:`A_2 = \\sum\\limits_{i=1}^{I} \\Bigg[\\sum\\limits_{h=1}^{H}N_{i,h,c}(o_{i,h} - \\mu_c -U_{c}x_{i,h,c} -D_{c}z_{i,c} )\\Bigg]E[y_{i}]^{\\top}` + + + More information, please, check the technical notes attached + + Parameters + ---------- + + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of ints + List of corresponding labels + + VProd: array + Matrix representing V_c.T*inv(Sigma_c) @ V_c.T + + n_acc: array + Accumulated 0th order statistics for each class (math:`N_{i}`) + + f_acc: array + Accumulated 1st order statistics for each class (math:`F_{i}`) + + latent_x: array + E(x) latent variable + + latent_y: array + E(y) latent variable + + latent_z: array + E(z) latent variable + + + Returns + ------- + + acc_V_A1: + (n_gaussians, r_V, r_V) A1 accumulator + + acc_V_A2: + (n_gaussians* feature_dimension, r_V) A2 accumulator + + """ + + # U accumulators + acc_V_A1 = np.zeros((self.ubm.n_gaussians, self.r_V, self.r_V)) + acc_V_A2 = np.zeros((self.supervector_dimension, self.r_V)) + + # Loops over all people + for i in set(y): + n_acc_i = n_acc[i] + f_acc_i = f_acc[i] + X_i = self._get_statistics_by_class_id(X, y, i) + latent_x_i = latent_x[i] + latent_y_i = latent_y[i] + latent_z_i = latent_z[i] + + # Computing A1 accumulator: \sum_{i=1}^{N}(E(y_i_c @ y_i_c.T)) + id_plus_prod_v_i = self._compute_id_plus_vprod_i(n_acc_i, VProd) + id_plus_prod_v_i += ( + latent_y_i[:, np.newaxis] @ latent_y_i[:, np.newaxis].T + ) + + acc_V_A1 += mult_along_axis( + id_plus_prod_v_i[np.newaxis].repeat( + self.ubm.n_gaussians, axis=0 + ), + n_acc_i, + axis=0, + ) + + # Computing A2 accumulator: \sum_{i=1}^{N}( \sum_{h=1}^{H}(N_i_h_c (o_i_h, - m_c - D_c*z_i_c - U_c*x_i_h_c))@ E(y_i).T ) + fn_y_i = self._compute_fn_y_i( + X_i, + latent_x_i=latent_x_i, + latent_z_i=latent_z_i, + n_acc_i=n_acc_i, + f_acc_i=f_acc_i, + ) + + acc_V_A2 += fn_y_i[np.newaxis].T @ latent_y_i[:, np.newaxis].T + + return acc_V_A1, acc_V_A2 + + def _compute_fn_y_i(self, X_i, latent_x_i, latent_z_i, n_acc_i, f_acc_i): + """ + // Compute Fn_yi = sum_{sessions h}(N_{i,h}*(o_{i,h} - m - D*z_{i} - U*x_{i,h}) (Normalized first order statistics) + See equation (30) in [McCool2013]_ + + Parameters + ---------- + + X_i: list of :py:class:`bob.learn.em.GMMStats` + List of statistics for a class + + latent_x_i: array + E(x_i) latent variable + + latent_z_i: array + E(z_i) latent variable + + n_acc_i: array + Accumulated 0th order statistics for 
each class (math:`N_{i}`) + + f_acc_i: array + Accumulated 1st order statistics for each class (math:`F_{i}`) + + + """ + + U = self._U + D = self._D # Not doing the JFA + + m = self.mean_supervector + + # y = self.y[i] # Not doing JFA + + tmp_CD = np.repeat(n_acc_i, self.feature_dimension) + + fn_y_i = f_acc_i.flatten() - tmp_CD * ( + m - D * latent_z_i + ) # Fn_yi = sum_{sessions h}(N_{i,h}*(o_{i,h} - m - D*z_{i}) + + # Looping over the sessions of a ;ane; + for session_id in range(len(X_i)): + n_i = X_i[session_id].n + U_dot_x = U @ latent_x_i[:, session_id] + tmp_CD = np.repeat(n_i, self.feature_dimension) + fn_y_i -= tmp_CD * U_dot_x + + return fn_y_i + + """ + Scoring + """ + + def estimate_x(self, X): + + id_plus_us_prod_inv = self._compute_id_plus_us_prod_inv(X) + fn_x = self._compute_fn_x(X) + + # UtSigmaInv * Fn_x = Ut*diag(sigma)^-1 * N*(o - m) + ut_inv_sigma = self._U.T / self.variance_supervector + + return id_plus_us_prod_inv @ (ut_inv_sigma @ fn_x) + + def estimate_ux(self, X): + x = self.estimate_x(X) + return self.U @ x + + def _compute_id_plus_us_prod_inv(self, X_i): + """ + Computes (Id + U^T.Sigma^-1.U.N_{i,h}.U)^-1 = + + Parameters + ---------- + + X_i: list of :py:class:`bob.learn.em.GMMStats` + List of statistics for a class + """ + I = np.eye(self.r_U, self.r_U) # noqa: E741 + + Uc = self._U.reshape( + (self.ubm.n_gaussians, self.feature_dimension, self.r_U) + ) + + UcT = np.transpose(Uc, axes=(0, 2, 1)) + + sigma_c = self.ubm.variances[:, np.newaxis] + + n_i_c = np.expand_dims(X_i.n[:, np.newaxis], axis=2) + + id_plus_us_prod_inv = I + (((UcT / sigma_c) @ Uc) * n_i_c).sum(axis=0) + id_plus_us_prod_inv = np.linalg.inv(id_plus_us_prod_inv) + + return id_plus_us_prod_inv + + def _compute_fn_x(self, X_i): + """ + Compute Fn_x = sum_{sessions h}(N*(o - m) (Normalized first order statistics) + + Parameters + ---------- + + X_i: list of :py:class:`bob.learn.em.GMMStats` + List of statistics for a class + + """ + + n = X_i.n[:, np.newaxis] + f = X_i.sum_px + + fn_x = f - self.ubm.means * n + + return fn_x.flatten() + + def score(self, model, data): + """ + Computes the ISV score using a numpy array as input + + Parameters + ---------- + latent_z : numpy.ndarray + Latent representation of the client (E[z_i]) + + data : list of :py:class:`bob.learn.em.GMMStats` + List of statistics to be scored + + Returns + ------- + score : float + The linear scored + + """ + + return self.score_using_stats(model, self.ubm.acc_stats(data)) + + def fit(self, X, y): + + input_is_dask, X = check_and_persist_dask_input(X, persist=False) + y = np.squeeze(np.asarray(y)) + check_consistent_length(X, y) + + self.initialize(X) + + if input_is_dask: + # split the X array based on the classes + X_new, y_new = [], [] + for class_id in unique_labels(y): + class_indices = y == class_id + X_new.append(X[class_indices]) + y_new.append(y[class_indices]) + X, y = X_new, y_new + del X_new, y_new + + stats = [ + dask.delayed(self.ubm.stats_per_sample)(xx).persist() + for xx in X + ] + else: + stats = self.ubm.stats_per_sample(X) + + del X + self.fit_using_stats(stats, y) + return self + + +class ISVMachine(FactorAnalysisBase): + """ + Implements the Intersession Variability Modelling hypothesis on top of GMMs + + Inter-Session Variability (ISV) modeling is a session variability modeling technique built on top of the Gaussian mixture modeling approach. 
+ It hypothesizes that within-class variations are embedded in a linear subspace in the GMM means subspace and these variations can be suppressed + by an offset w.r.t each mean during the MAP adaptation. + For more information check [McCool2013]_ + + Parameters + ---------- + + r_U: int + Dimension of the subspace U + + em_iterations: int + Number of EM iterations + + relevance_factor: float + Factor analysis relevance factor + + random_state: int + random_state for the random number generator + + ubm: :py:class:`bob.learn.em.GMMMachine` or None + A trained UBM (Universal Background Model). If None, the UBM is trained with + a new :py:class:`bob.learn.em.GMMMachine` when fit is called, with `ubm_kwargs` + as parameters. + + """ + + def __init__( + self, + r_U, + em_iterations=10, + relevance_factor=4.0, + random_state=0, + ubm=None, + ubm_kwargs=None, + **kwargs, + ): + super().__init__( + r_U=r_U, + relevance_factor=relevance_factor, + em_iterations=em_iterations, + random_state=random_state, + ubm=ubm, + ubm_kwargs=ubm_kwargs, + **kwargs, + ) + + def e_step(self, X, y, n_samples_per_class, n_acc, f_acc): + """ + E-step of the EM algorithm + """ + # self.initialize_XYZ(y) + UProd = self._compute_uprod() + _, _, latent_z = self.initialize_XYZ( + n_samples_per_class=n_samples_per_class + ) + latent_y = None + + latent_x = self.compute_latent_x( + X=X, + y=y, + n_classes=len(n_samples_per_class), + UProd=UProd, + ) + latent_z = self.update_z( + X=X, + y=y, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, + n_acc=n_acc, + f_acc=f_acc, + ) + acc_U_A1, acc_U_A2 = self.compute_accumulators_U( + X, y, UProd, latent_x, latent_y, latent_z + ) + + return acc_U_A1, acc_U_A2 + + def m_step(self, acc_U_A1_acc_U_A2_list): + """ + ISV M-step. + This updates `U` matrix + + Parameters + ---------- + + acc_U_A1: array + Accumulated statistics for U_A1(n_gaussians, r_U, r_U) + + acc_U_A2: array + Accumulated statistics for U_A2(n_gaussians* feature_dimension, r_U) + + """ + acc_U_A1 = [acc[0] for acc in acc_U_A1_acc_U_A2_list] + acc_U_A1 = reduce_iadd(acc_U_A1) + + acc_U_A2 = [acc[1] for acc in acc_U_A1_acc_U_A2_list] + acc_U_A2 = reduce_iadd(acc_U_A2) + + return self.update_U(acc_U_A1, acc_U_A2) + + def fit_using_stats(self, X, y): + """ + Trains the U matrix (session variability matrix) + + Parameters + ---------- + X : numpy.ndarray + Nxd features of N GMM statistics + y : numpy.ndarray + The input labels, a 1D numpy array of shape (number of samples, ) + + Returns + ------- + self : object + Returns self. 
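+
+        Example
+        -------
+        A minimal sketch (illustrative random values only, not a reference
+        result), following the usage in the package's tests::
+
+            >>> import numpy as np
+            >>> from bob.learn.em import GMMMachine, GMMStats, ISVMachine
+            >>> ubm = GMMMachine(2, 3)
+            >>> ubm.means = np.random.rand(2, 3)
+            >>> ubm.variances = np.random.rand(2, 3) + 0.1
+            >>> stats = []
+            >>> for _ in range(4):  # 4 samples belonging to 2 classes
+            ...     gs = GMMStats(2, 3)
+            ...     gs.n = np.random.rand(2)
+            ...     gs.sum_px = np.random.rand(2, 3)
+            ...     stats.append(gs)
+            >>> isv = ISVMachine(r_U=2, em_iterations=2, ubm=ubm)
+            >>> isv = isv.fit_using_stats(stats, [0, 0, 1, 1])
+            >>> isv.U.shape
+            (6, 2)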
+ + """ + ( + input_is_dask, + n_classes, + n_samples_per_class, + ) = check_dask_input_samples_per_class(X, y) + + n_acc, f_acc = self.initialize_using_stats(X, y, n_classes) + + for i in range(self.em_iterations): + logger.info("U Training: Iteration %d", i + 1) + if input_is_dask: + e_step_output = [ + dask.delayed(self.e_step)( + X=xx, + y=yy, + n_samples_per_class=n_samples_per_class, + n_acc=n_acc, + f_acc=f_acc, + ) + for xx, yy in zip(X, y) + ] + delayed_em_step = dask.delayed(self.m_step)(e_step_output) + self._U = dask.compute(delayed_em_step)[0] + else: + e_step_output = self.e_step( + X=X, + y=y, + n_samples_per_class=n_samples_per_class, + n_acc=n_acc, + f_acc=f_acc, + ) + self.m_step([e_step_output]) + + return self + + def transform(self, X): + ubm_projected_X = self.ubm.acc_stats(X) + return self.estimate_ux(ubm_projected_X) + + def enroll_using_stats(self, X, iterations=1): + """ + Enrolls a new client + In ISV, the enrolment is defined as: :math:`m + Dz` with the latent variables `z` + representing the enrolled model. + + Parameters + ---------- + X : list of :py:class:`bob.learn.em.GMMStats` + List of statistics to be enrolled + + + iterations : int + Number of iterations to perform + + Returns + ------- + self : object + z + + """ + # We have only one class for enrollment + y = list(np.zeros(len(X), dtype=np.int32)) + n_acc = self._sum_n_statistics(X, y=y, n_classes=1) + f_acc = self._sum_f_statistics(X, y=y, n_classes=1) + + UProd = self._compute_uprod() + _, _, latent_z = self.initialize_XYZ(n_samples_per_class=[len(X)]) + latent_y = None + for i in range(iterations): + logger.info("Enrollment: Iteration %d", i + 1) + latent_x = self.compute_latent_x( + X=X, + y=y, + n_classes=1, + UProd=UProd, + latent_y=latent_y, + latent_z=latent_z, + ) + latent_z = self.update_z( + X=X, + y=y, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, + n_acc=n_acc, + f_acc=f_acc, + ) + + return latent_z + + def enroll(self, X, iterations=1): + """ + Enrolls a new client using a numpy array as input + + Parameters + ---------- + X : array + features to be enrolled + + iterations : int + Number of iterations to perform + + Returns + ------- + self : object + z + + """ + return self.enroll_using_stats([self.ubm.acc_stats(X)], iterations) + + def score_using_stats(self, latent_z, data): + """ + Computes the ISV score + + Parameters + ---------- + latent_z : numpy.ndarray + Latent representation of the client (E[z_i]) + + data : list of :py:class:`bob.learn.em.GMMStats` + List of statistics to be scored + + Returns + ------- + score : float + The linear scored + + """ + x = self.estimate_x(data) + Ux = self._U @ x + + # TODO: I don't know why this is not the enrolled model + # Here I am just reproducing the C++ implementation + # m + Dz + z = self.D * latent_z + self.mean_supervector + + return linear_scoring( + z.reshape((self.ubm.n_gaussians, self.feature_dimension)), + self.ubm, + data, + Ux.reshape((self.ubm.n_gaussians, self.feature_dimension)), + frame_length_normalization=True, + )[0] + + +class JFAMachine(FactorAnalysisBase): + """ + Joint Factor Analysis (JFA) is an extension of ISV. Besides the + within-class assumption (modeled with :math:`U`), it also hypothesize that + between class variations are embedded in a low rank rectangular matrix + :math:`V`. In the supervector notation, this modeling has the following shape: + :math:`\\mu_{i, j} = m + Ux_{i, j} + Vy_{i} + D_z{i}`. 
+ + For more information check [McCool2013]_ + + Parameters + ---------- + + ubm: :py:class:`bob.learn.em.GMMMachine` + A trained UBM (Universal Background Model) + + r_U: int + Dimension of the subspace U + + r_V: int + Dimension of the subspace V + + em_iterations: int + Number of EM iterations + + relevance_factor: float + Factor analysis relevance factor + + random_state: int + random_state for the random number generator + + """ + + def __init__( + self, + r_U, + r_V, + em_iterations=10, + relevance_factor=4.0, + random_state=0, + ubm=None, + ubm_kwargs=None, + **kwargs, + ): + super().__init__( + ubm=ubm, + r_U=r_U, + r_V=r_V, + relevance_factor=relevance_factor, + em_iterations=em_iterations, + random_state=random_state, + ubm_kwargs=ubm_kwargs, + **kwargs, + ) + + def e_step_v(self, X, y, n_samples_per_class, n_acc, f_acc): + """ + ISV E-step for the V matrix. + + Parameters + ---------- + + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of int + List of labels + + n_classes: int + Number of classes + + n_acc: array + Accumulated 0th-order statistics + + f_acc: array + Accumulated 1st-order statistics + + + Returns + ---------- + + acc_V_A1: array + Accumulated statistics for V_A1(n_gaussians, r_V, r_V) + + acc_V_A2: array + Accumulated statistics for V_A2(n_gaussians* feature_dimension, r_V) + + """ + + VProd = self._compute_vprod() + + latent_x, latent_y, latent_z = self.initialize_XYZ( + n_samples_per_class=n_samples_per_class + ) + + # UPDATE Y, X AND FINALLY Z + + n_classes = len(n_samples_per_class) + latent_y = self.update_y( + X=X, + y=y, + n_classes=n_classes, + VProd=VProd, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, + n_acc=n_acc, + f_acc=f_acc, + ) + + acc_V_A1, acc_V_A2 = self.compute_accumulators_V( + X=X, + y=y, + VProd=VProd, + n_acc=n_acc, + f_acc=f_acc, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, + ) + + return acc_V_A1, acc_V_A2 + + def m_step_v(self, acc_V_A1_acc_V_A2_list): + """ + `V` Matrix M-step. 
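+        Per Gaussian component :math:`c`, the update solves
+        :math:`V_c = A_{2,c} A_{1,c}^{-1}` using the accumulators returned by
+        :py:meth:`e_step_v`.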
+ This updates the `V` matrix + + Parameters + ---------- + + acc_V_A1: array + Accumulated statistics for V_A1(n_gaussians, r_V, r_V) + + acc_V_A2: array + Accumulated statistics for V_A2(n_gaussians* feature_dimension, r_V) + + """ + acc_V_A1 = [acc[0] for acc in acc_V_A1_acc_V_A2_list] + acc_V_A1 = reduce_iadd(acc_V_A1) + + acc_V_A2 = [acc[1] for acc in acc_V_A1_acc_V_A2_list] + acc_V_A2 = reduce_iadd(acc_V_A2) + + # Inverting A1 over the zero axis + # https://stackoverflow.com/questions/11972102/is-there-a-way-to-efficiently-invert-an-array-of-matrices-with-numpy + inv_A1 = np.linalg.inv(acc_V_A1) + + V_c = ( + acc_V_A2.reshape( + (self.ubm.n_gaussians, self.feature_dimension, self.r_V) + ) + @ inv_A1 + ) + self._V = V_c.reshape( + (self.ubm.n_gaussians * self.feature_dimension, self.r_V) + ) + return self._V + + def finalize_v(self, X, y, n_samples_per_class, n_acc, f_acc): + """ + Compute for the last time `E[y]` + + Parameters + ---------- + + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of int + List of labels + + n_classes: int + Number of classes + + n_acc: array + Accumulated 0th-order statistics + + f_acc: array + Accumulated 1st-order statistics + + Returns + ------- + latent_y: array + E[y] + + """ + VProd = self._compute_vprod() + + latent_x, latent_y, latent_z = self.initialize_XYZ( + n_samples_per_class=n_samples_per_class + ) + + # UPDATE Y, X AND FINALLY Z + + n_classes = len(n_samples_per_class) + latent_y = self.update_y( + X=X, + y=y, + n_classes=n_classes, + VProd=VProd, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, + n_acc=n_acc, + f_acc=f_acc, + ) + return latent_y + + def e_step_u(self, X, y, n_samples_per_class, latent_y): + """ + ISV E-step for the U matrix. + + Parameters + ---------- + + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of int + List of labels + + latent_y: array + E(y) latent variable + + + Returns + ---------- + + acc_U_A1: array + Accumulated statistics for U_A1(n_gaussians, r_U, r_U) + + acc_U_A2: array + Accumulated statistics for U_A2(n_gaussians* feature_dimension, r_U) + + """ + # self.initialize_XYZ(y) + UProd = self._compute_uprod() + latent_x, _, latent_z = self.initialize_XYZ( + n_samples_per_class=n_samples_per_class + ) + + n_classes = len(n_samples_per_class) + latent_x = self.compute_latent_x( + X=X, + y=y, + n_classes=n_classes, + UProd=UProd, + latent_y=latent_y, + ) + + acc_U_A1, acc_U_A2 = self.compute_accumulators_U( + X=X, + y=y, + UProd=UProd, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, + ) + + return acc_U_A1, acc_U_A2 + + def m_step_u(self, acc_U_A1_acc_U_A2_list): + """ + `U` Matrix M-step. 
+ This updates the `U` matrix + + Parameters + ---------- + + acc_V_A1: array + Accumulated statistics for V_A1(n_gaussians, r_V, r_V) + + acc_V_A2: array + Accumulated statistics for V_A2(n_gaussians* feature_dimension, r_V) + + """ + acc_U_A1 = [acc[0] for acc in acc_U_A1_acc_U_A2_list] + acc_U_A2 = [acc[1] for acc in acc_U_A1_acc_U_A2_list] + + acc_U_A1 = reduce_iadd(acc_U_A1) + acc_U_A2 = reduce_iadd(acc_U_A2) + + return self.update_U(acc_U_A1, acc_U_A2) + + def finalize_u( + self, + X, + y, + n_samples_per_class, + latent_y, + ): + """ + Compute for the last time `E[x]` + + Parameters + ---------- + + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of int + List of labels + + n_classes: int + Number of classes + + latent_y: array + E[y] latent variable + + Returns + ------- + latent_x: array + E[x] + """ + + UProd = self._compute_uprod() + latent_x, _, _ = self.initialize_XYZ( + n_samples_per_class=n_samples_per_class + ) + + n_classes = len(n_samples_per_class) + latent_x = self.compute_latent_x( + X=X, + y=y, + n_classes=n_classes, + UProd=UProd, + latent_y=latent_y, + ) + + return latent_x + + def e_step_d( + self, X, y, n_samples_per_class, latent_x, latent_y, n_acc, f_acc + ): + """ + ISV E-step for the U matrix. + + Parameters + ---------- + + X: list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + y: list of int + List of labels + + n_classes: int + Number of classes + + latent_x: array + E(x) latent variable + + latent_y: array + E(y) latent variable + + latent_z: array + E(z) latent variable + + n_acc: array + Accumulated 0th-order statistics + + f_acc: array + Accumulated 1st-order statistics + + + Returns + ---------- + + acc_D_A1: array + Accumulated statistics for D_A1(n_gaussians* feature_dimension, ) + + acc_D_A2: array + Accumulated statistics for D_A2(n_gaussians* feature_dimension, ) + + """ + _, _, latent_z = self.initialize_XYZ( + n_samples_per_class=n_samples_per_class + ) + + latent_z = self.update_z( + X, + y, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, + n_acc=n_acc, + f_acc=f_acc, + ) + + acc_D_A1, acc_D_A2 = self.compute_accumulators_D( + X, y, latent_x, latent_y, latent_z, n_acc, f_acc + ) + + return acc_D_A1, acc_D_A2 + + def m_step_d(self, acc_D_A1_acc_D_A2_list): + """ + `D` Matrix M-step. + This updates the `D` matrix + + Parameters + ---------- + + acc_D_A1: array + Accumulated statistics for D_A1(n_gaussians* feature_dimension, ) + + acc_D_A2: array + Accumulated statistics for D_A2(n_gaussians* feature_dimension, ) + + """ + acc_D_A1 = [acc[0] for acc in acc_D_A1_acc_D_A2_list] + acc_D_A2 = [acc[1] for acc in acc_D_A1_acc_D_A2_list] + + acc_D_A1 = reduce_iadd(acc_D_A1) + acc_D_A2 = reduce_iadd(acc_D_A2) + + self._D = acc_D_A2 / acc_D_A1 + return self._D + + def enroll_using_stats(self, X, iterations=1): + """ + Enrolls a new client. + In JFA the enrolment is defined as: :math:`m + Vy + Dz` with the latent variables `y` and `z` + representing the enrolled model. 
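+        The returned pair ``(latent_y, latent_z)`` is what
+        :py:meth:`score_using_stats` expects as its ``model`` argument.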
+ + Parameters + ---------- + X : list of :py:class:`bob.learn.em.GMMStats` + List of statistics + + iterations : int + Number of iterations to perform + + Returns + ------- + self : array + z, y latent variables + + """ + # We have only one class for enrollment + y = list(np.zeros(len(X), dtype=np.int32)) + n_acc = self._sum_n_statistics(X, y=y, n_classes=1) + f_acc = self._sum_f_statistics(X, y=y, n_classes=1) + + UProd = self._compute_uprod() + VProd = self._compute_vprod() + latent_x, latent_y, latent_z = self.initialize_XYZ( + n_samples_per_class=[len(X)] + ) + + for i in range(iterations): + logger.info("Enrollment: Iteration %d", i + 1) + latent_y = self.update_y( + X=X, + y=y, + n_classes=1, + VProd=VProd, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, + n_acc=n_acc, + f_acc=f_acc, + ) + latent_x = self.compute_latent_x( + X=X, + y=y, + n_classes=1, + UProd=UProd, + latent_y=latent_y, + latent_z=latent_z, + ) + latent_z = self.update_z( + X=X, + y=y, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, + n_acc=n_acc, + f_acc=f_acc, + ) + + # The latent variables are wrapped in to 2axis arrays + return latent_y[0], latent_z[0] + + def enroll(self, X, iterations=1): + """ + Enrolls a new client using a numpy array as input + + Parameters + ---------- + X : array + features to be enrolled + + iterations : int + Number of iterations to perform + + Returns + ------- + self : object + z + + """ + return self.enroll_using_stats([self.ubm.acc_stats(X)], iterations) + + def fit_using_stats(self, X, y): + """ + Trains the U matrix (session variability matrix) + + Parameters + ---------- + X : numpy.ndarray + Nxd features of N GMM statistics + y : numpy.ndarray + The input labels, a 1D numpy array of shape (number of samples, ) + + Returns + ------- + self : object + Returns self. 
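+
+        Example
+        -------
+        A minimal sketch (illustrative random values only, not a reference
+        result), with ``ubm`` and ``stats`` built as in the
+        :py:meth:`ISVMachine.fit_using_stats` example::
+
+            >>> from bob.learn.em import JFAMachine
+            >>> jfa = JFAMachine(r_U=2, r_V=2, em_iterations=2, ubm=ubm)
+            >>> jfa = jfa.fit_using_stats(stats, [0, 0, 1, 1])
+            >>> jfa.U.shape, jfa.V.shape, jfa.D.shape
+            ((6, 2), (6, 2), (6,))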
+ + """ + + # In case those variables are already set + if ( + not hasattr(self, "_U") + or not hasattr(self, "_V") + or not hasattr(self, "_D") + ): + self.create_UVD() + + ( + input_is_dask, + n_classes, + n_samples_per_class, + ) = check_dask_input_samples_per_class(X, y) + + n_acc, f_acc = self.initialize_using_stats(X, y, n_classes=n_classes) + + # Updating V + for i in range(self.em_iterations): + logger.info("V Training: Iteration %d", i + 1) + if input_is_dask: + e_step_output = [ + dask.delayed(self.e_step_v)( + X=xx, + y=yy, + n_samples_per_class=n_samples_per_class, + n_acc=n_acc, + f_acc=f_acc, + ) + for xx, yy in zip(X, y) + ] + delayed_em_step = dask.delayed(self.m_step_v)(e_step_output) + self._V = dask.compute(delayed_em_step)[0] + else: + e_step_output = self.e_step_v( + X=X, + y=y, + n_samples_per_class=n_samples_per_class, + n_acc=n_acc, + f_acc=f_acc, + ) + self.m_step_v([e_step_output]) + latent_y = self.finalize_v( + X=X, + y=y, + n_samples_per_class=n_samples_per_class, + n_acc=n_acc, + f_acc=f_acc, + ) + + # Updating U + for i in range(self.em_iterations): + logger.info("U Training: Iteration %d", i + 1) + if input_is_dask: + e_step_output = [ + dask.delayed(self.e_step_u)( + X=xx, + y=yy, + n_samples_per_class=n_samples_per_class, + latent_y=latent_y, + ) + for xx, yy in zip(X, y) + ] + delayed_em_step = dask.delayed(self.m_step_u)(e_step_output) + self._U = dask.compute(delayed_em_step)[0] + else: + e_step_output = self.e_step_u( + X=X, + y=y, + n_samples_per_class=n_samples_per_class, + latent_y=latent_y, + ) + self.m_step_u([e_step_output]) + + latent_x = self.finalize_u( + X=X, y=y, n_samples_per_class=n_samples_per_class, latent_y=latent_y + ) + + # Updating D + for i in range(self.em_iterations): + logger.info("D Training: Iteration %d", i + 1) + if input_is_dask: + e_step_output = [ + dask.delayed(self.e_step_d)( + X=xx, + y=yy, + n_samples_per_class=n_samples_per_class, + latent_x=latent_x, + latent_y=latent_y, + n_acc=n_acc, + f_acc=f_acc, + ) + for xx, yy in zip(X, y) + ] + delayed_em_step = dask.delayed(self.m_step_d)(e_step_output) + self._D = dask.compute(delayed_em_step)[0] + else: + e_step_output = self.e_step_d( + X=X, + y=y, + n_samples_per_class=n_samples_per_class, + latent_x=latent_x, + latent_y=latent_y, + n_acc=n_acc, + f_acc=f_acc, + ) + self.m_step_d([e_step_output]) + + return self + + def score_using_stats(self, model, data): + """ + Computes the JFA score + + Parameters + ---------- + latent_z : numpy.ndarray + Latent representation of the client (E[z_i]) + + data : list of :py:class:`bob.learn.em.GMMStats` + List of statistics to be scored + + Returns + ------- + score : float + The linear scored + + """ + latent_y = model[0] + latent_z = model[1] + + x = self.estimate_x(data) + Ux = self._U @ x + + # TODO: I don't know why this is not the enrolled model + # Here I am just reproducing the C++ implementation + # m + Vy + Dz + zy = self.V @ latent_y + self.D * latent_z + self.mean_supervector + + return linear_scoring( + zy.reshape((self.ubm.n_gaussians, self.feature_dimension)), + self.ubm, + data, + Ux.reshape((self.ubm.n_gaussians, self.feature_dimension)), + frame_length_normalization=True, + )[0] diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py index 2d5b0b9340e69dda2f9f585ebc9c16c83d006d64..988ef3a5b512bb6b4075e6886816f96f22b03edb 100644 --- a/bob/learn/em/gmm.py +++ b/bob/learn/em/gmm.py @@ -18,11 +18,8 @@ import numpy as np from h5py import File as HDF5File from sklearn.base import BaseEstimator -from .kmeans import ( - 
KMeansMachine, - array_to_delayed_list, - check_and_persist_dask_input, -) +from .kmeans import KMeansMachine +from .utils import array_to_delayed_list, check_and_persist_dask_input logger = logging.getLogger(__name__) @@ -135,16 +132,14 @@ def e_step(data, machine): # Count of samples [int] statistics.t += data.shape[0] # Responsibilities [array of shape (n_gaussians,)] - statistics.n = statistics.n + responsibility.sum(axis=-1) + statistics.n += responsibility.sum(axis=-1) for i in range(n_gaussians): # p * x [array of shape (n_gaussians, n_samples, n_features)] px = responsibility[i, :, None] * data # First order stats [array of shape (n_gaussians, n_features)] - statistics.sum_px[i] = statistics.sum_px[i] + np.sum(px, axis=0) + statistics.sum_px[i] += np.sum(px, axis=0) # Second order stats [array of shape (n_gaussians, n_features)] - statistics.sum_pxx[i] = statistics.sum_pxx[i] + np.sum( - px * data, axis=0 - ) + statistics.sum_pxx[i] += np.sum(px * data, axis=0) return statistics @@ -855,12 +850,17 @@ class GMMMachine(BaseEstimator): ) return self - def transform(self, X, **kwargs): + def acc_stats(self, X): """Returns the statistics for `X`.""" - return e_step( - data=X, - machine=self, - ) + # we need this because sometimes the transform function gets overridden + return e_step(data=X, machine=self) + + def transform(self, X): + """Returns the statistics for `X`.""" + return self.acc_stats(X) + + def stats_per_sample(self, X): + return [e_step(data=xx, machine=self) for xx in X] def _more_tags(self): return { diff --git a/bob/learn/em/kmeans.py b/bob/learn/em/kmeans.py index 8a17af323e3574d9e17bd991f26a6756f290cc3b..61af76ec9a54dce65c43d7439ecbbb8a78b61031 100644 --- a/bob/learn/em/kmeans.py +++ b/bob/learn/em/kmeans.py @@ -15,6 +15,8 @@ import scipy.spatial.distance from dask_ml.cluster.k_means import k_init from sklearn.base import BaseEstimator +from .utils import array_to_delayed_list, check_and_persist_dask_input + logger = logging.getLogger(__name__) @@ -171,32 +173,6 @@ def reduce_indices_means_vars(stats): return variances, weights -def check_and_persist_dask_input(data): - # check if input is a dask array. If so, persist and rebalance data - input_is_dask = False - if isinstance(data, da.Array): - data: da.Array = data.persist() - input_is_dask = True - # if there is a dask distributed client, rebalance data - try: - client = dask.distributed.Client.current() - client.rebalance() - except ValueError: - pass - - else: - data = np.asarray(data) - return input_is_dask, data - - -def array_to_delayed_list(data, input_is_dask): - # If input is a dask array, convert to delayed chunks - if input_is_dask: - data = data.to_delayed().ravel().tolist() - logger.debug(f"Got {len(data)} chunks.") - return data - - class KMeansMachine(BaseEstimator): """Stores the k-means clusters parameters (centroid of each cluster). 
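Note that `check_and_persist_dask_input` and `array_to_delayed_list` now live in a new `bob/learn/em/utils.py` module, which is not included in the hunks of this patch. A rough sketch of what that module presumably contains, inferred from the code removed from `kmeans.py` and from the new `persist=False` call site in `FactorAnalysisBase.fit` (the `persist` keyword and its default are assumptions):

# bob/learn/em/utils.py -- sketch only; the actual module is not shown in this diff
import logging

import dask
import dask.array as da
import dask.distributed
import numpy as np

logger = logging.getLogger(__name__)


def check_and_persist_dask_input(data, persist=True):
    # Check if the input is a dask array. If so, optionally persist and rebalance it.
    input_is_dask = False
    if isinstance(data, da.Array):
        input_is_dask = True
        if persist:
            data = data.persist()
            # If there is a dask distributed client, rebalance the persisted data.
            try:
                client = dask.distributed.Client.current()
                client.rebalance()
            except ValueError:
                pass
    else:
        data = np.asarray(data)
    return input_is_dask, data


def array_to_delayed_list(data, input_is_dask):
    # If the input is a dask array, convert it to a list of delayed chunks.
    if input_is_dask:
        data = data.to_delayed().ravel().tolist()
        logger.debug(f"Got {len(data)} chunks.")
    return data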
diff --git a/bob/learn/em/test/test_factor_analysis.py b/bob/learn/em/test/test_factor_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..596a6a9fe2b75fa52b062e6286d4589e5ccff36d --- /dev/null +++ b/bob/learn/em/test/test_factor_analysis.py @@ -0,0 +1,584 @@ +#!/usr/bin/env python +# Laurent El Shafey <Laurent.El-Shafey@idiap.ch> +# Tiago Freitas Pereira <tiago.pereira@idiap.ch> +# Amir Mohammadi <amir.mohammadi@idiap.ch> + +import copy + +import numpy as np + +from bob.learn.em import GMMMachine, GMMStats, ISVMachine, JFAMachine + +from .test_gmm import multiprocess_dask_client +from .test_kmeans import to_dask_array, to_numpy + +# Define Training set and initial values for tests +F1 = np.array( + [ + 0.3833, + 0.4516, + 0.6173, + 0.2277, + 0.5755, + 0.8044, + 0.5301, + 0.9861, + 0.2751, + 0.0300, + 0.2486, + 0.5357, + ] +).reshape((6, 2)) +F2 = np.array( + [ + 0.0871, + 0.6838, + 0.8021, + 0.7837, + 0.9891, + 0.5341, + 0.0669, + 0.8854, + 0.9394, + 0.8990, + 0.0182, + 0.6259, + ] +).reshape((6, 2)) +F = [F1, F2] + +N1 = np.array([0.1379, 0.1821, 0.2178, 0.0418]).reshape((2, 2)) +N2 = np.array([0.1069, 0.9397, 0.6164, 0.3545]).reshape((2, 2)) +N = [N1, N2] + +gs11 = GMMStats(2, 3) +gs11.n = N1[:, 0] +gs11.sum_px = F1[:, 0].reshape(2, 3) +gs12 = GMMStats(2, 3) +gs12.n = N1[:, 1] +gs12.sum_px = F1[:, 1].reshape(2, 3) + +gs21 = GMMStats(2, 3) +gs21.n = N2[:, 0] +gs21.sum_px = F2[:, 0].reshape(2, 3) +gs22 = GMMStats(2, 3) +gs22.n = N2[:, 1] +gs22.sum_px = F2[:, 1].reshape(2, 3) + +TRAINING_STATS_X = [gs11, gs12, gs21, gs22] +TRAINING_STATS_y = [0, 0, 1, 1] +UBM_MEAN = np.array([0.1806, 0.0451, 0.7232, 0.3474, 0.6606, 0.3839]) +UBM_VAR = np.array([0.6273, 0.0216, 0.9106, 0.8006, 0.7458, 0.8131]) +M_d = np.array([0.4106, 0.9843, 0.9456, 0.6766, 0.9883, 0.7668]) +M_v = np.array( + [ + 0.3367, + 0.4116, + 0.6624, + 0.6026, + 0.2442, + 0.7505, + 0.2955, + 0.5835, + 0.6802, + 0.5518, + 0.5278, + 0.5836, + ] +).reshape((6, 2)) +M_u = np.array( + [ + 0.5118, + 0.3464, + 0.0826, + 0.8865, + 0.7196, + 0.4547, + 0.9962, + 0.4134, + 0.3545, + 0.2177, + 0.9713, + 0.1257, + ] +).reshape((6, 2)) + +z1 = np.array([0.3089, 0.7261, 0.7829, 0.6938, 0.0098, 0.8432]) +z2 = np.array([0.9223, 0.7710, 0.0427, 0.3782, 0.7043, 0.7295]) +y1 = np.array([0.2243, 0.2691]) +y2 = np.array([0.6730, 0.4775]) +x1 = np.array([0.9976, 0.8116, 0.1375, 0.3900]).reshape((2, 2)) +x2 = np.array([0.4857, 0.8944, 0.9274, 0.9175]).reshape((2, 2)) +M_z = [z1, z2] +M_y = [y1, y2] +M_x = [x1, x2] + + +def test_JFATrainAndEnrol(): + # Train and enroll a JFAMachine + + # Calls the train function + ubm = GMMMachine(2, 3) + ubm.means = UBM_MEAN.reshape((2, 3)) + ubm.variances = UBM_VAR.reshape((2, 3)) + it = JFAMachine(2, 2, em_iterations=10, ubm=ubm) + + it.U = copy.deepcopy(M_u) + it.V = copy.deepcopy(M_v) + it.D = copy.deepcopy(M_d) + it.fit_using_stats(TRAINING_STATS_X, TRAINING_STATS_y) + + v_ref = np.array( + [ + [0.245364911936476, 0.978133261775424], + [0.769646805052223, 0.940070736856596], + [0.310779202800089, 1.456332053893072], + [0.184760934399551, 2.265139705602147], + [0.701987784039800, 0.081632150899400], + [0.074344030229297, 1.090248340917255], + ], + "float64", + ) + u_ref = np.array( + [ + [0.049424652628448, 0.060480486336896], + [0.178104127464007, 1.884873813495153], + [1.204011484266777, 2.281351307871720], + [7.278512126426286, -0.390966087173334], + [-0.084424326581145, -0.081725474934414], + [4.042143689831097, -0.262576386580701], + ], + "float64", + ) + d_ref = np.array( + [ + 
9.648467e-18, + 2.63720683155e-12, + 2.11822157653706e-10, + 9.1047243e-17, + 1.41163442535567e-10, + 3.30581e-19, + ], + "float64", + ) + + eps = 1e-10 + np.testing.assert_allclose(it.V, v_ref, rtol=eps, atol=1e-8) + np.testing.assert_allclose(it.U, u_ref, rtol=eps, atol=1e-8) + np.testing.assert_allclose(it.D, d_ref, rtol=eps, atol=1e-8) + + # Calls the enroll function + + Ne = np.array([0.1579, 0.9245, 0.1323, 0.2458]).reshape((2, 2)) + Fe = np.array( + [ + 0.1579, + 0.1925, + 0.3242, + 0.1234, + 0.2354, + 0.2734, + 0.2514, + 0.5874, + 0.3345, + 0.2463, + 0.4789, + 0.5236, + ] + ).reshape((6, 2)) + gse1 = GMMStats(2, 3) + gse1.n = Ne[:, 0] + gse1.sum_px = Fe[:, 0].reshape(2, 3) + gse2 = GMMStats(2, 3) + gse2.n = Ne[:, 1] + gse2.sum_px = Fe[:, 1].reshape(2, 3) + + gse = [gse1, gse2] + latent_y, latent_z = it.enroll_using_stats(gse, 5) + + y_ref = np.array([0.555991469319657, 0.002773650670010], "float64") + z_ref = np.array( + [ + 8.2228e-20, + 3.15216909492e-13, + -1.48616735364395e-10, + 1.0625905e-17, + 3.7150503117895e-11, + 1.71104e-19, + ], + "float64", + ) + + np.testing.assert_allclose(latent_y, y_ref, rtol=eps, atol=1e-8) + np.testing.assert_allclose(latent_z, z_ref, rtol=eps, atol=1e-8) + + +def test_ISVTrainAndEnrol(): + # Train and enroll an 'ISVMachine' + + eps = 1e-10 + d_ref = np.array( + [ + 0.39601136, + 0.07348469, + 0.47712682, + 0.44738127, + 0.43179856, + 0.45086029, + ], + "float64", + ) + u_ref = np.array( + [ + [0.855125642430777, 0.563104284748032], + [-0.325497865404680, 1.923598985291687], + [0.511575659503837, 1.964288663083095], + [9.330165761678115, 1.073623827995043], + [0.511099245664012, 0.278551249248978], + [5.065578541930268, 0.509565618051587], + ], + "float64", + ) + z_ref = np.array( + [ + [ + -0.079315777443826, + 0.092702428248543, + -0.342488761656616, + -0.059922635809136, + 0.133539981073604, + 0.213118695516570, + ] + ], + "float64", + ) + + """ + Calls the train function + """ + ubm = GMMMachine(2, 3) + ubm.means = UBM_MEAN.reshape((2, 3)) + ubm.variances = UBM_VAR.reshape((2, 3)) + + it = ISVMachine( + ubm=ubm, + r_U=2, + relevance_factor=4.0, + em_iterations=10, + ) + + it.U = copy.deepcopy(M_u) + it = it.fit_using_stats(TRAINING_STATS_X, TRAINING_STATS_y) + + np.testing.assert_allclose(it.D, d_ref, rtol=eps, atol=1e-8) + np.testing.assert_allclose(it.U, u_ref, rtol=eps, atol=1e-8) + + """ + Calls the enroll function + """ + + Ne = np.array([0.1579, 0.9245, 0.1323, 0.2458]).reshape((2, 2)) + Fe = np.array( + [ + 0.1579, + 0.1925, + 0.3242, + 0.1234, + 0.2354, + 0.2734, + 0.2514, + 0.5874, + 0.3345, + 0.2463, + 0.4789, + 0.5236, + ] + ).reshape((6, 2)) + gse1 = GMMStats(2, 3) + gse1.n = Ne[:, 0] + gse1.sum_px = Fe[:, 0].reshape(2, 3) + gse2 = GMMStats(2, 3) + gse2.n = Ne[:, 1] + gse2.sum_px = Fe[:, 1].reshape(2, 3) + + gse = [gse1, gse2] + latent_z = it.enroll_using_stats(gse, 5) + np.testing.assert_allclose(latent_z, z_ref, rtol=eps, atol=1e-8) + + +def test_JFATrainInitialize(): + # Check that the initialization is consistent and using the rng (cf. 
issue #118) + + eps = 1e-10 + + # UBM GMM + ubm = GMMMachine(2, 3) + ubm.means = UBM_MEAN.reshape((2, 3)) + ubm.variances = UBM_VAR.reshape((2, 3)) + + # JFA + it = JFAMachine(2, 2, em_iterations=10, ubm=ubm) + # first round + + n_classes = it.estimate_number_of_classes(TRAINING_STATS_y) + it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y, n_classes) + u1 = it.U + v1 = it.V + d1 = it.D + + # second round + it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y, n_classes) + u2 = it.U + v2 = it.V + d2 = it.D + + np.testing.assert_allclose(u1, u2, rtol=eps, atol=1e-8) + np.testing.assert_allclose(v1, v2, rtol=eps, atol=1e-8) + np.testing.assert_allclose(d1, d2, rtol=eps, atol=1e-8) + + +def test_ISVTrainInitialize(): + + # Check that the initialization is consistent and using the rng (cf. issue #118) + eps = 1e-10 + + # UBM GMM + ubm = GMMMachine(2, 3) + ubm.means = UBM_MEAN.reshape((2, 3)) + ubm.variances = UBM_VAR.reshape((2, 3)) + + # ISV + it = ISVMachine(2, em_iterations=10, ubm=ubm) + # it.rng = rng + + n_classes = it.estimate_number_of_classes(TRAINING_STATS_y) + it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y, n_classes) + u1 = copy.deepcopy(it.U) + d1 = copy.deepcopy(it.D) + + # second round + it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y, n_classes) + u2 = it.U + d2 = it.D + + np.testing.assert_allclose(u1, u2, rtol=eps, atol=1e-8) + np.testing.assert_allclose(d1, d2, rtol=eps, atol=1e-8) + + +def test_JFAMachine(): + + eps = 1e-10 + + # Creates a UBM + ubm = GMMMachine(2, 3) + ubm.weights = np.array([0.4, 0.6], "float64") + ubm.means = np.array([[1, 6, 2], [4, 3, 2]], "float64") + ubm.variances = np.array([[1, 2, 1], [2, 1, 2]], "float64") + + # Defines GMMStats + gs = GMMStats(2, 3) + gs.log_likelihood = -3.0 + gs.t = 1 + gs.n = np.array([0.4, 0.6], "float64") + gs.sum_px = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], "float64") + gs.sum_pxx = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]], "float64") + + # Creates a JFAMachine + m = JFAMachine(2, 2, em_iterations=10, ubm=ubm) + m.U = np.array( + [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], "float64" + ) + m.V = np.array([[6, 5], [4, 3], [2, 1], [1, 2], [3, 4], [5, 6]], "float64") + m.D = np.array([0, 1, 0, 1, 0, 1], "float64") + + # Preparing the model + y = np.array([1, 2], "float64") + z = np.array([3, 4, 1, 2, 0, 1], "float64") + model = [y, z] + + score_ref = -2.111577181208289 + score = m.score_using_stats(model, gs) + np.testing.assert_allclose(score, score_ref, atol=eps) + + # Scoring with numpy array + np.random.seed(0) + X = np.random.normal(loc=0.0, scale=1.0, size=(50, 3)) + score_ref = 2.028009315286946 + score = m.score(model, X) + np.testing.assert_allclose(score, score_ref, atol=eps) + + +def test_ISVMachine(): + + eps = 1e-10 + + # Creates a UBM + ubm = GMMMachine(2, 3) + ubm.weights = np.array([0.4, 0.6], "float64") + ubm.means = np.array([[1, 6, 2], [4, 3, 2]], "float64") + ubm.variances = np.array([[1, 2, 1], [2, 1, 2]], "float64") + + # Creates a ISVMachine + isv_machine = ISVMachine(ubm=ubm, r_U=2, em_iterations=10) + isv_machine.U = np.array( + [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], "float64" + ) + # base.v = np.array([[0], [0], [0], [0], [0], [0]], 'float64') + isv_machine.D = np.array([0, 1, 0, 1, 0, 1], "float64") + + # Defines GMMStats + gs = GMMStats(2, 3) + gs.log_likelihood = -3.0 + gs.t = 1 + gs.n = np.array([0.4, 0.6], "float64") + gs.sum_px = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], "float64") + gs.sum_pxx = np.array([[10.0, 
20.0, 30.0], [40.0, 50.0, 60.0]], "float64") + + # Enrolled model + latent_z = np.array([3, 4, 1, 2, 0, 1], "float64") + score = isv_machine.score_using_stats(latent_z, gs) + score_ref = -3.280498193082100 + np.testing.assert_allclose(score, score_ref, atol=eps) + + # Scoring with numpy array + np.random.seed(0) + X = np.random.normal(loc=0.0, scale=1.0, size=(50, 3)) + score_ref = -1.2343813195374242 + score = isv_machine.score(latent_z, X) + np.testing.assert_allclose(score, score_ref, atol=eps) + + +def _create_ubm_prior(means): + # Creating a fake prior with 2 gaussians + prior_gmm = GMMMachine(2) + prior_gmm.means = means.copy() + # All nice and round diagonal covariance + prior_gmm.variances = np.ones((2, 3)) * 0.5 + prior_gmm.weights = np.array([0.3, 0.7]) + return prior_gmm + + +def test_ISV_JFA_fit(): + np.random.seed(10) + data_class1 = np.random.normal(0, 0.5, (10, 3)) + data_class2 = np.random.normal(-0.2, 0.2, (10, 3)) + data = np.concatenate([data_class1, data_class2], axis=0) + labels = [0] * 10 + [1] * 10 + means = np.vstack( + (np.random.normal(0, 0.5, (1, 3)), np.random.normal(1, 0.5, (1, 3))) + ) + prior_U = [ + [-0.150035, -0.44441], + [-1.67812, 2.47621], + [-0.52885, 0.659141], + [-0.538446, 1.67376], + [-0.111288, 2.06948], + [1.39563, -1.65004], + ] + + prior_V = [ + [0.732467, 0.281321], + [0.543212, -0.512974], + [1.04108, 0.835224], + [-0.363719, -0.324688], + [-1.21579, -0.905314], + [-0.993204, -0.121991], + ] + + prior_D = [ + 0.943986, + -0.0900599, + -0.528103, + 0.541502, + -0.717824, + 0.463729, + ] + + for prior, machine_type, ref in [ + ( + None, + "isv", + 0.0, + ), + ( + True, + "isv", + [ + [-0.01018673, -0.0266506], + [-0.00160621, -0.00420217], + [0.02811705, 0.07356008], + [0.011624, 0.0304108], + [0.03261831, 0.08533629], + [0.04602191, 0.12040291], + ], + ), + ( + None, + "jfa", + [ + [-0.05673845, -0.0543068], + [-0.05302666, -0.05075409], + [-0.02522509, -0.02414402], + [-0.05723968, -0.05478655], + [-0.05291602, -0.05064819], + [-0.02463007, -0.0235745], + ], + ), + ( + True, + "jfa", + [ + [0.002881, -0.00584225], + [0.04143539, -0.08402497], + [-0.26149924, 0.53028251], + [-0.25156832, 0.51014406], + [-0.38687765, 0.78453174], + [-0.36015821, 0.73034858], + ], + ), + ]: + ref = np.asarray(ref) + + # Doing the training + for transform in (to_numpy, to_dask_array): + data, labels = transform(data, labels) + + if prior is None: + ubm = None + # we still provide an initial UBM because KMeans training is not + # determenistic depending on inputting numpy or dask arrays + ubm_kwargs = dict(n_gaussians=2, ubm=_create_ubm_prior(means)) + else: + ubm = _create_ubm_prior(means) + ubm_kwargs = None + + machine_kwargs = dict( + ubm=ubm, + relevance_factor=4, + em_iterations=50, + ubm_kwargs=ubm_kwargs, + random_state=10, + ) + + if machine_type == "isv": + machine = ISVMachine(2, **machine_kwargs) + machine.U = prior_U + test_attr = "U" + else: + machine = JFAMachine(2, 2, **machine_kwargs) + machine.U = prior_U + machine.V = prior_V + machine.D = prior_D + test_attr = "V" + + err_msg = f"Test failed with prior={prior} and machine_type={machine_type} and transform={transform}" + with multiprocess_dask_client(): + machine.fit(data, labels) + + arr = getattr(machine, test_attr) + np.testing.assert_allclose( + arr, + ref, + atol=1e-7, + err_msg=err_msg, + ) diff --git a/bob/learn/em/utils.py b/bob/learn/em/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..be92e8d41258384e08b06a07d212d1da328e048f --- /dev/null +++ 
b/bob/learn/em/utils.py @@ -0,0 +1,34 @@ +import logging + +import dask +import dask.array as da +import numpy as np + +logger = logging.getLogger(__name__) + + +def check_and_persist_dask_input(data, persist=True): + # check if input is a dask array. If so, persist and rebalance data + input_is_dask = False + if isinstance(data, da.Array): + if persist: + data: da.Array = data.persist() + input_is_dask = True + # if there is a dask distributed client, rebalance data + try: + client = dask.distributed.Client.current() + client.rebalance() + except ValueError: + pass + + else: + data = np.asarray(data) + return input_is_dask, data + + +def array_to_delayed_list(data, input_is_dask): + # If input is a dask array, convert to delayed chunks + if input_is_dask: + data = data.to_delayed().ravel().tolist() + logger.debug(f"Got {len(data)} chunks.") + return data diff --git a/bob/learn/em/wccn.py b/bob/learn/em/wccn.py index 7c2c8246c4f70c80f3ae5499ce70db65c8976f65..47d5b80463ebb74e0c1e7f8b8dccc1f52ed31529 100644 --- a/bob/learn/em/wccn.py +++ b/bob/learn/em/wccn.py @@ -1,3 +1,6 @@ +#!/usr/bin/env python +# @author: Tiago de Freitas Pereira + import dask # Dask doesn't have an implementation for `pinv` diff --git a/bob/learn/em/whitening.py b/bob/learn/em/whitening.py index 3565f4fb2397bb7aadcd5e81727e79562056270e..bf81a36203580120e2d3f90664bed38401809d58 100644 --- a/bob/learn/em/whitening.py +++ b/bob/learn/em/whitening.py @@ -1,3 +1,6 @@ +#!/usr/bin/env python +# @author: Tiago de Freitas Pereira + import dask from scipy.linalg import pinv diff --git a/buildout.cfg b/buildout.cfg index d3c5cee7a448e29f08c7627e5a14591ec2e73122..044f2367cdf0522cb7d4a8b7159e47be68d2f696 100644 --- a/buildout.cfg +++ b/buildout.cfg @@ -8,6 +8,7 @@ eggs = bob.learn.em extensions = bob.buildout newest = false verbose = true +debug = true [scripts] recipe = bob.buildout:scripts diff --git a/doc/extra-intersphinx.txt b/doc/extra-intersphinx.txt deleted file mode 100644 index e0c2a1af70a9f804401fb9c4654cb2338090983c..0000000000000000000000000000000000000000 --- a/doc/extra-intersphinx.txt +++ /dev/null @@ -1,2 +0,0 @@ -# The bob.core>2.0.5 in the requirements.txt is making the bob.core not download -bob.core diff --git a/doc/guide.rst b/doc/guide.rst index 82d0b95ac172b6e14aa8ccb8220fc2b6405632ce..c445a32d44d5ef38d5384d4bd0063caa1584c42c 100644 --- a/doc/guide.rst +++ b/doc/guide.rst @@ -101,8 +101,8 @@ This statistical model is defined in the class :options: +NORMALIZE_WHITESPACE +SKIP >>> import bob.learn.em - >>> # Create a GMM with k=2 Gaussians with the dimensionality of 3 - >>> gmm_machine = bob.learn.em.GMMMachine(2, 3) + >>> # Create a GMM with k=2 Gaussians + >>> gmm_machine = bob.learn.em.GMMMachine(n_gaussians=2) There are plenty of ways to estimate :math:`\Theta`; the next subsections @@ -118,7 +118,7 @@ the parameters of a statistical model given observations by finding the :math:`\Theta` that maximizes :math:`P(x|\Theta)` for all :math:`x` in your dataset [9]_. This optimization is done by the **Expectation-Maximization** (EM) algorithm [8]_ and it is implemented by -:py:class:`bob.learn.em.ML_GMMTrainer`. +:py:class:`bob.learn.em.GMMMachine` by setting the keyword argument `trainer="ml"`. A very nice explanation of EM algorithm for the maximum likelihood estimation can be found in this @@ -130,7 +130,7 @@ estimator. .. doctest:: - :options: +NORMALIZE_WHITESPACE +SKIP + :options: +NORMALIZE_WHITESPACE >>> import bob.learn.em >>> import numpy @@ -141,35 +141,22 @@ estimator. ... [-7,7,-100], ... 
[-5,5,-101]], dtype='float64') >>> # Create a kmeans model (machine) m with k=2 clusters - >>> # with a dimensionality equal to 3 - >>> gmm_machine = bob.learn.em.GMMMachine(2, 3) - >>> # Using the MLE trainer to train the GMM: - >>> # True, True, True means update means/variances/weights at each - >>> # iteration - >>> gmm_trainer = bob.learn.em.ML_GMMTrainer(True, True, True) - >>> # Setting some means to start the training. - >>> # In practice, the output of kmeans is a good start for the MLE training - >>> gmm_machine.means = numpy.array( - ... [[ -4., 2.3, -10.5], - ... [ 2.5, -4.5, 59. ]]) - >>> max_iterations = 200 - >>> convergence_threshold = 1e-5 + >>> # and using the MLE trainer to train the GMM: + >>> # In this setup, kmeans is used to initialize the means, variances and weights of the gaussians + >>> gmm_machine = bob.learn.em.GMMMachine(n_gaussians=2, trainer="ml") >>> # Training - >>> bob.learn.em.train(gmm_trainer, gmm_machine, data, - ... max_iterations=max_iterations, - ... convergence_threshold=convergence_threshold) + >>> gmm_machine = gmm_machine.fit(data) >>> print(gmm_machine.means) - [[ -6. 6. -100.5] - [ 3.5 -3.5 99. ]] + [[ 3.5 -3.5 99. ] + [ -6. 6. -100.5]] Bellow follow an intuition of the GMM trained the maximum likelihood estimator using the Iris flower `dataset <https://en.wikipedia.org/wiki/Iris_flower_data_set>`_. -.. - TODO uncomment when implemented - .. plot:: plot/plot_ML.py - :include-source: False + +.. plot:: plot/plot_ML.py + :include-source: False Maximum a posteriori Estimator (MAP) @@ -181,7 +168,7 @@ estimate that equals the mode of the posterior distribution by incorporating in its loss function a prior distribution [10]_. Commonly this prior distribution (the values of :math:`\Theta`) is estimated with MLE. This optimization is done by the **Expectation-Maximization** (EM) algorithm [8]_ and it is implemented -by :py:class:`bob.learn.em.MAP_GMMTrainer`. +by :py:class:`bob.learn.em.GMMMachine` by setting the keyword argument `trainer="map"`. A compact way to write relevance MAP adaptation is by using GMM supervector notation (this will be useful in the next subsections). The GMM supervector @@ -195,7 +182,7 @@ Follow bellow an snippet on how to train a GMM using the MAP estimator. .. doctest:: - :options: +NORMALIZE_WHITESPACE +SKIP + :options: +NORMALIZE_WHITESPACE >>> import bob.learn.em >>> import numpy @@ -206,33 +193,31 @@ Follow bellow an snippet on how to train a GMM using the MAP estimator. ... [-7,7,-100], ... [-5,5,-101]], dtype='float64') >>> # Creating a fake prior - >>> prior_gmm = bob.learn.em.GMMMachine(2, 3) - >>> # Set some random means for the example + >>> prior_gmm = bob.learn.em.GMMMachine(2) + >>> # Set some random means/variances and weights for the example >>> prior_gmm.means = numpy.array( ... [[ -4., 2.3, -10.5], ... [ 2.5, -4.5, 59. ]]) - >>> # Creating the model for the adapted GMM - >>> adapted_gmm = bob.learn.em.GMMMachine(2, 3) - >>> # Creating the MAP trainer - >>> gmm_trainer = bob.learn.em.MAP_GMMTrainer(prior_gmm, relevance_factor=4) - >>> - >>> max_iterations = 200 - >>> convergence_threshold = 1e-5 + >>> prior_gmm.variances = numpy.array( + ... [[ -0.1, 0.5, -0.5], + ... 
[ 0.5, -0.5, 0.2 ]]) + >>> prior_gmm.weights = numpy.array([ 0.8, 0.5]) + >>> # Creating the model for the adapted GMM, and setting the `prior_gmm` as the source GMM + >>> # note that we have set `trainer="map"`, so we use the Maximum a posteriori estimator + >>> adapted_gmm = bob.learn.em.GMMMachine(2, ubm=prior_gmm, trainer="map") >>> # Training - >>> bob.learn.em.train(gmm_trainer, adapted_gmm, data, - ... max_iterations=max_iterations, - ... convergence_threshold=convergence_threshold) + >>> adapted_gmm = adapted_gmm.fit(data) >>> print(adapted_gmm.means) - [[ -4.667 3.533 -40.5 ] - [ 2.929 -4.071 76.143]] + [[ -4. 2.3 -10.5 ] + [ 0.944 -1.833 36.889]] Bellow follow an intuition of the GMM trained with the MAP estimator using the Iris flower `dataset <https://en.wikipedia.org/wiki/Iris_flower_data_set>`_. +It can be observed how the MAP means (the red triangles) around the center of each class +from a prior GMM (the blue crosses). -.. - TODO uncomment when implemented - .. plot:: plot/plot_MAP.py - :include-source: False +.. plot:: plot/plot_MAP.py + :include-source: False Session Variability Modeling with Gaussian Mixture Models @@ -273,7 +258,7 @@ prior GMM. .. doctest:: - :options: +NORMALIZE_WHITESPACE +SKIP + :options: +NORMALIZE_WHITESPACE >>> import bob.learn.em >>> import numpy @@ -285,21 +270,13 @@ prior GMM. ... [-0.3, -0.1, 0], ... [1.2, 1.4, 1], ... [0.8, 1., 1]], dtype='float64') - >>> # Creating a fake prior with 2 Gaussians of dimension 3 - >>> prior_gmm = bob.learn.em.GMMMachine(2, 3) - >>> prior_gmm.means = numpy.vstack((numpy.random.normal(0, 0.5, (1, 3)), - ... numpy.random.normal(1, 0.5, (1, 3)))) - >>> # All nice and round diagonal covariance - >>> prior_gmm.variances = numpy.ones((2, 3)) * 0.5 - >>> prior_gmm.weights = numpy.array([0.3, 0.7]) + >>> # Training a GMM with 2 Gaussians of dimension 3 + >>> prior_gmm = bob.learn.em.GMMMachine(2).fit(data) >>> # Creating the container - >>> gmm_stats_container = bob.learn.em.GMMStats(2, 3) - >>> for d in data: - ... prior_gmm.acc_statistics(d, gmm_stats_container) - >>> + >>> gmm_stats = prior_gmm.transform(data) >>> # Printing the responsibilities - >>> print(gmm_stats_container.n/gmm_stats_container.t) - [0.429 0.571] + >>> print(gmm_stats.n/gmm_stats.t) + [0.6 0.4] Inter-Session Variability @@ -307,80 +284,69 @@ Inter-Session Variability .. _isv: Inter-Session Variability (ISV) modeling [3]_ [2]_ is a session variability -modeling technique built on top of the Gaussian mixture modeling approach. It -hypothesizes that within-class variations are embedded in a linear subspace in -the GMM means subspace and these variations can be suppressed by an offset w.r.t -each mean during the MAP adaptation. +modeling technique built on top of the Gaussian mixture modeling approach. +It hypothesizes that within-class variations are embedded in a linear subspace in +the GMM means subspace, and these variations can be suppressed by an offset w.r.t each mean during the MAP adaptation. -In this generative model each sample is assumed to have been generated by a GMM -mean supervector with the following shape: + +In this generative model, each sample is assumed to have been generated by a GMM mean supervector with the following shape: :math:`\mu_{i, j} = m + Ux_{i, j} + D_z{i}`, where :math:`m` is our prior, :math:`Ux_{i, j}` is the session offset that we want to suppress and :math:`D_z{i}` is the class offset (with all session effects suppressed). -All possible sources of session variations is embedded in this matrix -:math:`U`. 
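For intuition, the supervector decomposition above can be written out directly with plain ``numpy``. The following is only a minimal sketch, not part of the patch: the shapes (a UBM with 2 Gaussians of dimension 3, hence a supervector of size 6, and ``r_U = 2``) and the random latent factors are assumptions for illustration, and ``D`` is treated as the diagonal of a matrix, i.e. it multiplies ``z`` elementwise::

    import numpy as np

    rng = np.random.default_rng(0)

    # Toy shapes: 2 Gaussians of dimension 3 give a supervector of size 6;
    # r_U = 2 session factors. All values here are placeholders, in practice
    # m comes from the UBM and x, z are estimated by EM during ISV training.
    m = rng.normal(size=6)        # prior mean supervector (flattened UBM means)
    U = rng.normal(size=(6, 2))   # within-class (session) variability subspace
    D = rng.normal(size=6)        # diagonal class-offset term, stored as a vector
    x_ij = rng.normal(size=2)     # latent session factor of sample j of class i
    z_i = rng.normal(size=6)      # latent class factor of class i

    # mu_{i,j} = m + U x_{i,j} + D z_i : U @ x_ij is the session offset that ISV
    # suppresses, D * z_i (elementwise, since D is diagonal) is the class offset.
    mu_ij = m + U @ x_ij + D * z_i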
Follow bellow an intuition of what is modeled with :math:`U` in the +It is hypothesized that all possible sources of session variations are embedded in this matrix +:math:`U`. Follow below an intuition of what is modeled with :math:`U` in the Iris flower `dataset <https://en.wikipedia.org/wiki/Iris_flower_data_set>`_. The arrows :math:`U_{1}`, :math:`U_{2}` and :math:`U_{3}` are the directions of -the within class variations, with respect to each Gaussian component, that will +the within-class variations, with respect to each Gaussian component, that will be suppressed a posteriori. -.. - TODO uncomment when implemented - .. plot:: plot/plot_ISV.py - :include-source: False + + +.. plot:: plot/plot_ISV.py + :include-source: False The ISV statistical model is stored in this container -:py:class:`bob.learn.em.ISVBase` and the training is performed by -:py:class:`bob.learn.em.ISVTrainer`. The snippet bellow shows how to train a -Intersession variability modeling. +:py:class:`bob.learn.em.ISVMachine`. +The snippet bellow shows how to: + + - Train a Intersession variability modeling. + - Enroll a subject with such a model. + - Compute score with such a model. .. doctest:: - :options: +NORMALIZE_WHITESPACE +SKIP + :options: +NORMALIZE_WHITESPACE + + >>> import bob.learn.em + >>> import numpy as np + + >>> np.random.seed(10) + + >>> # Generating some fake data + >>> data_class1 = np.random.normal(0, 0.5, (10, 3)) + >>> data_class2 = np.random.normal(-0.2, 0.2, (10, 3)) + >>> X = np.vstack((data_class1, data_class2)) + >>> y = np.hstack((np.zeros(10, dtype=int), np.ones(10, dtype=int))) + >>> # Create an ISV machine with a UBM of 2 gaussians + >>> isv_machine = bob.learn.em.ISVMachine(r_U=2, ubm_kwargs=dict(n_gaussians=2)) + >>> _ = isv_machine.fit(X, y) # DOCTEST: +SKIP_ + >>> isv_machine.U + array(...) + + >>> # Enrolling a subject + >>> enroll_data = np.array([[1.2, 0.1, 1.4], [0.5, 0.2, 0.3]]) + >>> model = isv_machine.enroll(enroll_data) + >>> print(model) + [[ 0.54 0.246 0.505 1.617 -0.791 0.746]] + + >>> # Probing + >>> probe_data = np.array([[1.2, 0.1, 1.4], [0.5, 0.2, 0.3]]) + >>> score = isv_machine.score(model, probe_data) + >>> print(score) + [2.754] - >>> import bob.learn.em - >>> import numpy - >>> numpy.random.seed(10) - >>> - >>> # Generating some fake data - >>> data_class1 = numpy.random.normal(0, 0.5, (10, 3)) - >>> data_class2 = numpy.random.normal(-0.2, 0.2, (10, 3)) - >>> data = [data_class1, data_class2] - - >>> # Creating a fake prior with 2 gaussians of dimension 3 - >>> prior_gmm = bob.learn.em.GMMMachine(2, 3) - >>> prior_gmm.means = numpy.vstack((numpy.random.normal(0, 0.5, (1, 3)), - ... numpy.random.normal(1, 0.5, (1, 3)))) - >>> # All nice and round diagonal covariance - >>> prior_gmm.variances = numpy.ones((2, 3)) * 0.5 - >>> prior_gmm.weights = numpy.array([0.3, 0.7]) - >>> # The input the the ISV Training is the statistics of the GMM - >>> gmm_stats_per_class = [] - >>> for d in data: - ... stats = [] - ... for i in d: - ... gmm_stats_container = bob.learn.em.GMMStats(2, 3) - ... prior_gmm.acc_statistics(i, gmm_stats_container) - ... stats.append(gmm_stats_container) - ... gmm_stats_per_class.append(stats) - - >>> # Finally doing the ISV training - >>> subspace_dimension_of_u = 2 - >>> relevance_factor = 4 - >>> isvbase = bob.learn.em.ISVBase(prior_gmm, subspace_dimension_of_u) - >>> trainer = bob.learn.em.ISVTrainer(relevance_factor) - >>> bob.learn.em.train(trainer, isvbase, gmm_stats_per_class, - ... 
max_iterations=50) - >>> # Printing the session offset w.r.t each Gaussian component - >>> print(isvbase.u) - [[-0.01 -0.027] - [-0.002 -0.004] - [ 0.028 0.074] - [ 0.012 0.03 ] - [ 0.033 0.085] - [ 0.046 0.12 ]] Joint Factor Analysis @@ -401,308 +367,47 @@ between class variations with respect to each Gaussian component that will be added a posteriori. -.. - TODO uncomment when implemented - .. plot:: plot/plot_JFA.py - :include-source: False +.. plot:: plot/plot_JFA.py + :include-source: False The JFA statistical model is stored in this container -:py:class:`bob.learn.em.JFABase` and the training is performed by -:py:class:`bob.learn.em.JFATrainer`. The snippet bellow shows how to train a -Intersession variability modeling. - -.. doctest:: - :options: +NORMALIZE_WHITESPACE +SKIP - - >>> import bob.learn.em - >>> import numpy - >>> numpy.random.seed(10) - >>> - >>> # Generating some fake data - >>> data_class1 = numpy.random.normal(0, 0.5, (10, 3)) - >>> data_class2 = numpy.random.normal(-0.2, 0.2, (10, 3)) - >>> data = [data_class1, data_class2] - - >>> # Creating a fake prior with 2 Gaussians of dimension 3 - >>> prior_gmm = bob.learn.em.GMMMachine(2, 3) - >>> prior_gmm.means = numpy.vstack((numpy.random.normal(0, 0.5, (1, 3)), - ... numpy.random.normal(1, 0.5, (1, 3)))) - >>> # All nice and round diagonal covariance - >>> prior_gmm.variances = numpy.ones((2, 3)) * 0.5 - >>> prior_gmm.weights = numpy.array([0.3, 0.7]) - >>> - >>> # The input the the JFA Training is the statistics of the GMM - >>> gmm_stats_per_class = [] - >>> for d in data: - ... stats = [] - ... for i in d: - ... gmm_stats_container = bob.learn.em.GMMStats(2, 3) - ... prior_gmm.acc_statistics(i, gmm_stats_container) - ... stats.append(gmm_stats_container) - ... gmm_stats_per_class.append(stats) - >>> - >>> # Finally doing the JFA training - >>> subspace_dimension_of_u = 2 - >>> subspace_dimension_of_v = 2 - >>> relevance_factor = 4 - >>> jfabase = bob.learn.em.JFABase(prior_gmm, subspace_dimension_of_u, - ... subspace_dimension_of_v) - >>> trainer = bob.learn.em.JFATrainer() - >>> bob.learn.em.train_jfa(trainer, jfabase, gmm_stats_per_class, - ... max_iterations=50) - - >>> # Printing the session offset w.r.t each Gaussian component - >>> print(jfabase.v) - [[ 0.003 -0.006] - [ 0.041 -0.084] - [-0.261 0.53 ] - [-0.252 0.51 ] - [-0.387 0.785] - [-0.36 0.73 ]] - -Total variability Modelling -=========================== -.. _ivector: - -Total Variability (TV) modeling [4]_ is a front-end initially introduced for -speaker recognition, which aims at describing samples by vectors of low -dimensionality called ``i-vectors``. The model consists of a subspace :math:`T` -and a residual diagonal covariance matrix :math:`\Sigma`, that are then used to -extract i-vectors, and is built upon the GMM approach. In the supervector -notation this modeling has the following shape: :math:`\mu = m + Tv`. - -Follow bellow an intuition of the data from the Iris flower -`dataset <https://en.wikipedia.org/wiki/Iris_flower_data_set>`_, embedded in -the iVector space. - -.. - TODO uncomment when implemented - .. plot:: plot/plot_iVector.py - :include-source: False - - -The iVector statistical model is stored in this container -:py:class:`bob.learn.em.IVectorMachine` and the training is performed by -:py:class:`bob.learn.em.IVectorTrainer`. The snippet bellow shows how to train -a Total variability modeling. - -.. 
doctest:: - :options: +NORMALIZE_WHITESPACE +SKIP - - >>> import bob.learn.em - >>> import numpy - >>> numpy.random.seed(10) - >>> - >>> # Generating some fake data - >>> data_class1 = numpy.random.normal(0, 0.5, (10, 3)) - >>> data_class2 = numpy.random.normal(-0.2, 0.2, (10, 3)) - >>> data = [data_class1, data_class2] - >>> - >>> # Creating a fake prior with 2 gaussians of dimension 3 - >>> prior_gmm = bob.learn.em.GMMMachine(2, 3) - >>> prior_gmm.means = numpy.vstack((numpy.random.normal(0, 0.5, (1, 3)), - ... numpy.random.normal(1, 0.5, (1, 3)))) - >>> # All nice and round diagonal covariance - >>> prior_gmm.variances = numpy.ones((2, 3)) * 0.5 - >>> prior_gmm.weights = numpy.array([0.3, 0.7]) - >>> - >>> # The input the the TV Training is the statistics of the GMM - >>> gmm_stats_per_class = [] - >>> for d in data: - ... for i in d: - ... gmm_stats_container = bob.learn.em.GMMStats(2, 3) - ... prior_gmm.acc_statistics(i, gmm_stats_container) - ... gmm_stats_per_class.append(gmm_stats_container) - >>> - >>> # Finally doing the TV training - >>> subspace_dimension_of_t = 2 - >>> - >>> ivector_trainer = bob.learn.em.IVectorTrainer(update_sigma=True) - >>> ivector_machine = bob.learn.em.IVectorMachine( - ... prior_gmm, subspace_dimension_of_t, 10e-5) - >>> # train IVector model - >>> bob.learn.em.train(ivector_trainer, ivector_machine, - ... gmm_stats_per_class, 500) - >>> - >>> # Printing the session offset w.r.t each Gaussian component - >>> print(ivector_machine.t) - [[ 0.11 -0.203] - [-0.124 0.014] - [ 0.296 0.674] - [ 0.447 0.174] - [ 0.425 0.583] - [ 0.394 0.794]] - -Linear Scoring -============== -.. _linearscoring: - -In :ref:`MAP <map>` adaptation, :ref:`ISV <isv>` and :ref:`JFA <jfa>` a -traditional way to do scoring is via the log-likelihood ratio between the -adapted model and the prior as the following: - -.. math:: - score = ln(P(x | \Theta)) - ln(P(x | \Theta_{prior})), +:py:class:`bob.learn.em.JFAMachine`. The snippet bellow shows how to train a +such session variability model. + - Train a JFA model. + - Enroll a subject with such a model. + - Compute score with such a model. -(with :math:`\Theta` varying for each approach). - -A simplification proposed by [Glembek2009]_, called linear scoring, -approximate this ratio using a first order Taylor series as the following: - -.. math:: - score = \frac{\mu - \mu_{prior}}{\sigma_{prior}} f * (\mu_{prior} + U_x), - -where :math:`\mu` is the the GMM mean supervector (of the prior and the adapted -model), :math:`\sigma` is the variance, supervector, :math:`f` is the first -order GMM statistics (:py:class:`bob.learn.em.GMMStats.sum_px`) and -:math:`U_x`, is possible channel offset (:ref:`ISV <isv>`). - -This scoring technique is implemented in :py:func:`bob.learn.em.linear_scoring`. -The snippet bellow shows how to compute scores using this approximation. .. doctest:: - :options: +NORMALIZE_WHITESPACE +SKIP + :options: +NORMALIZE_WHITESPACE >>> import bob.learn.em - >>> import numpy - >>> # Defining a fake prior - >>> prior_gmm = bob.learn.em.GMMMachine(3, 2) - >>> prior_gmm.means = numpy.array([[1, 1], [2, 2.1], [3, 3]]) - >>> # Defining a fake prior - >>> adapted_gmm = bob.learn.em.GMMMachine(3,2) - >>> adapted_gmm.means = numpy.array([[1.5, 1.5], [2.5, 2.5], [2, 2]]) - >>> # Defining an input - >>> input = numpy.array([[1.5, 1.5], [1.6, 1.6]]) - >>> #Accumulating statistics of the GMM - >>> stats = bob.learn.em.GMMStats(3, 2) - >>> prior_gmm.acc_statistics(input, stats) - >>> score = bob.learn.em.linear_scoring( - ... 
[adapted_gmm], prior_gmm, [stats], [], - ... frame_length_normalisation=True) + >>> import numpy as np + + >>> np.random.seed(10) + + >>> # Generating some fake data + >>> data_class1 = np.random.normal(0, 0.5, (10, 3)) + >>> data_class2 = np.random.normal(-0.2, 0.2, (10, 3)) + >>> X = np.vstack((data_class1, data_class2)) + >>> y = np.hstack((np.zeros(10, dtype=int), np.ones(10, dtype=int))) + >>> # Create a JFA machine with a UBM of 2 gaussians + >>> jfa_machine = bob.learn.em.JFAMachine(r_U=2, r_V=2, ubm_kwargs=dict(n_gaussians=2)) + >>> _ = jfa_machine.fit(X, y) + >>> jfa_machine.U + array(...) + + >>> enroll_data = np.array([[1.2, 0.1, 1.4], [0.5, 0.2, 0.3]]) + >>> model = jfa_machine.enroll(enroll_data) + >>> print(model) + (array([0.634, 0.165]), array([ 0., 0., 0., 0., -0., 0.])) + + >>> probe_data = np.array([[1.2, 0.1, 1.4], [0.5, 0.2, 0.3]]) + >>> score = jfa_machine.score(model, probe_data) >>> print(score) - [[0.254]] - - -Probabilistic Linear Discriminant Analysis (PLDA) -------------------------------------------------- - -Probabilistic Linear Discriminant Analysis [5]_ is a probabilistic model that -incorporates components describing both between-class and within-class -variations. Given a mean :math:`\mu`, between-class and within-class subspaces -:math:`F` and :math:`G` and residual noise :math:`\epsilon` with zero mean and -diagonal covariance matrix :math:`\Sigma`, the model assumes that a sample -:math:`x_{i,j}` is generated by the following process: - -.. math:: - - x_{i,j} = \mu + F h_{i} + G w_{i,j} + \epsilon_{i,j} - - -An Expectation-Maximization algorithm can be used to learn the parameters of -this model :math:`\mu`, :math:`F` :math:`G` and :math:`\Sigma`. As these -parameters can be shared between classes, there is a specific container class -for this purpose, which is :py:class:`bob.learn.em.PLDABase`. The process is -described in detail in [6]_. - -Let us consider a training set of two classes, each with 3 samples of -dimensionality 3. - -.. doctest:: - :options: +NORMALIZE_WHITESPACE +SKIP - - >>> data1 = numpy.array( - ... [[3,-3,100], - ... [4,-4,50], - ... [40,-40,150]], dtype=numpy.float64) - >>> data2 = numpy.array( - ... [[3,6,-50], - ... [4,8,-100], - ... [40,79,-800]], dtype=numpy.float64) - >>> data = [data1,data2] - -Learning a PLDA model can be performed by instantiating the class -:py:class:`bob.learn.em.PLDATrainer`, and calling the -:py:meth:`bob.learn.em.train` method. - -.. doctest:: - :options: +SKIP - - >>> # This creates a PLDABase container for input feature of dimensionality - >>> # 3 and with subspaces F and G of rank 1 and 2, respectively. - >>> pldabase = bob.learn.em.PLDABase(3,1,2) - - >>> trainer = bob.learn.em.PLDATrainer() - >>> bob.learn.em.train(trainer, pldabase, data, max_iterations=10) - -Once trained, this PLDA model can be used to compute the log-likelihood of a -set of samples given some hypothesis. For this purpose, a -:py:class:`bob.learn.em.PLDAMachine` should be instantiated. Then, the -log-likelihood that a set of samples share the same latent identity variable -:math:`h_{i}` (i.e. the samples are coming from the same identity/class) is -obtained by calling the -:py:meth:`bob.learn.em.PLDAMachine.compute_log_likelihood()` method. - -.. doctest:: - :options: +SKIP - - >>> plda = bob.learn.em.PLDAMachine(pldabase) - >>> samples = numpy.array( - ... [[3.5,-3.4,102], - ... 
[4.5,-4.3,56]], dtype=numpy.float64) - >>> loglike = plda.compute_log_likelihood(samples) - -If separate models for different classes need to be enrolled, each of them with -a set of enrollment samples, then, several instances of -:py:class:`bob.learn.em.PLDAMachine` need to be created and enrolled using -the :py:meth:`bob.learn.em.PLDATrainer.enroll()` method as follows. - -.. doctest:: - :options: +SKIP - - >>> plda1 = bob.learn.em.PLDAMachine(pldabase) - >>> samples1 = numpy.array( - ... [[3.5,-3.4,102], - ... [4.5,-4.3,56]], dtype=numpy.float64) - >>> trainer.enroll(plda1, samples1) - >>> plda2 = bob.learn.em.PLDAMachine(pldabase) - >>> samples2 = numpy.array( - ... [[3.5,7,-49], - ... [4.5,8.9,-99]], dtype=numpy.float64) - >>> trainer.enroll(plda2, samples2) - -Afterwards, the joint log-likelihood of the enrollment samples and of one or -several test samples can be computed as previously described, and this -separately for each model. - -.. doctest:: - :options: +SKIP - - >>> sample = numpy.array([3.2,-3.3,58], dtype=numpy.float64) - >>> l1 = plda1.compute_log_likelihood(sample) - >>> l2 = plda2.compute_log_likelihood(sample) - -In a verification scenario, there are two possible hypotheses: - -#. :math:`x_{test}` and :math:`x_{enroll}` share the same class. -#. :math:`x_{test}` and :math:`x_{enroll}` are from different classes. - -Using the methods :py:meth:`bob.learn.em.PLDAMachine.log_likelihood_ratio` or -its alias ``__call__`` function, the corresponding log-likelihood ratio will be -computed, which is defined in more formal way by: -:math:`s = \ln(P(x_{test},x_{enroll})) - \ln(P(x_{test})P(x_{enroll}))` - -.. doctest:: - :options: +SKIP - - >>> s1 = plda1(sample) - >>> s2 = plda2(sample) - -.. testcleanup:: * + [0.471] - import shutil - os.chdir(current_directory) - shutil.rmtree(temp_dir) @@ -711,9 +416,6 @@ computed, which is defined in more formal way by: .. [1] http://dx.doi.org/10.1109/TASL.2006.881693 .. [2] http://publications.idiap.ch/index.php/publications/show/2606 .. [3] http://dx.doi.org/10.1016/j.csl.2007.05.003 -.. [4] http://dx.doi.org/10.1109/TASL.2010.2064307 -.. [5] http://dx.doi.org/10.1109/ICCV.2007.4409052 -.. [6] http://doi.ieeecomputersociety.org/10.1109/TPAMI.2013.38 .. [7] http://en.wikipedia.org/wiki/K-means_clustering .. [8] http://en.wikipedia.org/wiki/Expectation-maximization_algorithm .. [9] http://en.wikipedia.org/wiki/Maximum_likelihood diff --git a/doc/index.rst b/doc/index.rst index d81413146ca3255cce7f5462ff8d295c14cb7a03..2e35eefa5c57f8c340a67046594dfced82ff60a6 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -37,7 +37,9 @@ References .. .. [Vogt2008] *R. Vogt, S. Sridharan*. **'Explicit Modelling of Session Variability for Speaker Verification'**, Computer Speech & Language, 2008, vol. 22, no. 1, pp. 17-38 .. - .. [McCool2013] *C. McCool, R. Wallace, M. McLaren, L. El Shafey, S. Marcel*. **'Session Variability Modelling for Face Authentication'**, IET Biometrics, 2013 + +.. [McCool2013] *C. McCool, R. Wallace, M. McLaren, L. El Shafey, S. Marcel*. **'Session Variability Modelling for Face Authentication'**, IET Biometrics, 2013 + .. .. [ElShafey2014] *Laurent El Shafey, Chris McCool, Roy Wallace, Sebastien Marcel*. **'A Scalable Formulation of Probabilistic Linear Discriminant Analysis: Applied to Face Recognition'**, TPAMI'2014 .. @@ -49,7 +51,6 @@ References .. .. [Roweis1998] Roweis, Sam. "EM algorithms for PCA and SPCA." Advances in neural information processing systems (1998): 626-632. -.. [Glembek2009] Glembek, Ondrej, et al. 
"Comparison of scoring methods used in speaker recognition with joint factor analysis." Acoustics, Speech and Signal Processing, 2009. ICASSP 2009. IEEE International Conference on. IEEE, 2009. Indices and tables diff --git a/doc/plot/plot_ISV.py b/doc/plot/plot_ISV.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a9c39617ddfbce6970078fc45dae99e66b750a --- /dev/null +++ b/doc/plot/plot_ISV.py @@ -0,0 +1,130 @@ +import matplotlib.pyplot as plt +import numpy as np + +from sklearn.datasets import load_iris + +import bob.learn.em + +np.random.seed(2) # FIXING A SEED + + +# GENERATING DATA +iris_data = load_iris() +X = np.column_stack((iris_data.data[:, 0], iris_data.data[:, 3])) +y = iris_data.target + +setosa = X[iris_data.target == 0] +versicolor = X[iris_data.target == 1] +virginica = X[iris_data.target == 2] + +n_gaussians = 3 +r_U = 1 + + +# TRAINING THE PRIOR +ubm = bob.learn.em.GMMMachine(n_gaussians) +# Initializing with old bob initialization +ubm.means = np.array( + [ + [5.0048631, 0.26047739], + [5.83509503, 1.40530362], + [6.76257257, 1.98965356], + ] +) +ubm.variances = np.array( + [ + [0.11311728, 0.05183813], + [0.11587106, 0.08492455], + [0.20482993, 0.10438209], + ] +) + +ubm.weights = np.array([0.36, 0.36, 0.28]) + +isv_machine = bob.learn.em.ISVMachine(r_U, em_iterations=50, ubm=ubm) +isv_machine.U = np.array( + [[-0.150035, -0.44441, -1.67812, 2.47621, -0.52885, 0.659141]] +).T + +isv_machine = isv_machine.fit(X, y) + +# Variability direction +u0 = isv_machine.U[0:2, 0] / np.linalg.norm(isv_machine.U[0:2, 0]) +u1 = isv_machine.U[2:4, 0] / np.linalg.norm(isv_machine.U[2:4, 0]) +u2 = isv_machine.U[4:6, 0] / np.linalg.norm(isv_machine.U[4:6, 0]) + +figure, ax = plt.subplots() +plt.scatter(setosa[:, 0], setosa[:, 1], c="darkcyan", label="setosa") +plt.scatter( + versicolor[:, 0], versicolor[:, 1], c="goldenrod", label="versicolor" +) +plt.scatter(virginica[:, 0], virginica[:, 1], c="dimgrey", label="virginica") + +plt.scatter( + ubm.means[:, 0], + ubm.means[:, 1], + c="blue", + marker="x", + label="centroids - mle", +) +# plt.scatter(ubm.means[:, 0], ubm.means[:, 1], c="blue", +# marker=".", label="within class varibility", s=0.01) + +ax.arrow( + ubm.means[0, 0], + ubm.means[0, 1], + u0[0], + u0[1], + fc="k", + ec="k", + head_width=0.05, + head_length=0.1, +) +ax.arrow( + ubm.means[1, 0], + ubm.means[1, 1], + u1[0], + u1[1], + fc="k", + ec="k", + head_width=0.05, + head_length=0.1, +) +ax.arrow( + ubm.means[2, 0], + ubm.means[2, 1], + u2[0], + u2[1], + fc="k", + ec="k", + head_width=0.05, + head_length=0.1, +) +plt.text( + ubm.means[0, 0] + u0[0], + ubm.means[0, 1] + u0[1] - 0.1, + r"$\mathbf{U}_1$", + fontsize=15, +) +plt.text( + ubm.means[1, 0] + u1[0], + ubm.means[1, 1] + u1[1] - 0.1, + r"$\mathbf{U}_2$", + fontsize=15, +) +plt.text( + ubm.means[2, 0] + u2[0], + ubm.means[2, 1] + u2[1] - 0.1, + r"$\mathbf{U}_3$", + fontsize=15, +) + +plt.xticks([], []) +plt.yticks([], []) + +# plt.grid(True) +plt.xlabel("Sepal length") +plt.ylabel("Petal width") +plt.legend() +plt.tight_layout() +plt.show() diff --git a/doc/plot/plot_JFA.py b/doc/plot/plot_JFA.py new file mode 100644 index 0000000000000000000000000000000000000000..a56ab32f584509ea1a50e9852df8a5386b4f1fea --- /dev/null +++ b/doc/plot/plot_JFA.py @@ -0,0 +1,201 @@ +import matplotlib.pyplot as plt +import numpy as np + +from sklearn.datasets import load_iris + +import bob.learn.em + +np.random.seed(2) # FIXING A SEED + + +# GENERATING DATA +iris_data = load_iris() +X = np.column_stack((iris_data.data[:, 
0], iris_data.data[:, 3])) +y = iris_data.target + + +setosa = X[iris_data.target == 0] +versicolor = X[iris_data.target == 1] +virginica = X[iris_data.target == 2] + +n_gaussians = 3 +r_U = 1 +r_V = 1 + + +# TRAINING THE PRIOR +ubm = bob.learn.em.GMMMachine(n_gaussians) +# Initializing with old bob initialization +ubm.means = np.array( + [ + [5.0048631, 0.26047739], + [5.83509503, 1.40530362], + [6.76257257, 1.98965356], + ] +) +ubm.variances = np.array( + [ + [0.11311728, 0.05183813], + [0.11587106, 0.08492455], + [0.20482993, 0.10438209], + ] +) + +ubm.weights = np.array([0.36, 0.36, 0.28]) + +jfa_machine = bob.learn.em.JFAMachine(r_U, r_V, ubm=ubm, em_iterations=50) + +# Initializing with old bob initialization +jfa_machine.U = np.array( + [[-0.150035, -0.44441, -1.67812, 2.47621, -0.52885, 0.659141]] +).T + +jfa_machine.Y = np.array( + [[-0.538446, 1.67376, -0.111288, 2.06948, 1.39563, -1.65004]] +).T +jfa_machine.D = np.array( + [0.732467, 0.281321, 0.543212, -0.512974, 1.04108, 0.835224] +) +jfa_machine = jfa_machine.fit(X, y) + + +# Variability direction U +u0 = jfa_machine.U[0:2, 0] / np.linalg.norm(jfa_machine.U[0:2, 0]) +u1 = jfa_machine.U[2:4, 0] / np.linalg.norm(jfa_machine.U[2:4, 0]) +u2 = jfa_machine.U[4:6, 0] / np.linalg.norm(jfa_machine.U[4:6, 0]) + + +# Variability direction V +v0 = jfa_machine.V[0:2, 0] / np.linalg.norm(jfa_machine.V[0:2, 0]) +v1 = jfa_machine.V[2:4, 0] / np.linalg.norm(jfa_machine.V[2:4, 0]) +v2 = jfa_machine.V[4:6, 0] / np.linalg.norm(jfa_machine.V[4:6, 0]) + + +figure, ax = plt.subplots() +plt.scatter(setosa[:, 0], setosa[:, 1], c="darkcyan", label="setosa") +plt.scatter( + versicolor[:, 0], versicolor[:, 1], c="goldenrod", label="versicolor" +) +plt.scatter(virginica[:, 0], virginica[:, 1], c="dimgrey", label="virginica") + +plt.scatter( + ubm.means[:, 0], + ubm.means[:, 1], + c="blue", + marker="x", + label="centroids - mle", +) +# plt.scatter(ubm.means[:, 0], ubm.means[:, 1], c="blue", +# marker=".", label="within class varibility", s=0.01) + +# U +ax.arrow( + ubm.means[0, 0], + ubm.means[0, 1], + u0[0], + u0[1], + fc="k", + ec="k", + head_width=0.05, + head_length=0.1, +) +ax.arrow( + ubm.means[1, 0], + ubm.means[1, 1], + u1[0], + u1[1], + fc="k", + ec="k", + head_width=0.05, + head_length=0.1, +) +ax.arrow( + ubm.means[2, 0], + ubm.means[2, 1], + u2[0], + u2[1], + fc="k", + ec="k", + head_width=0.05, + head_length=0.1, +) +plt.text( + ubm.means[0, 0] + u0[0], + ubm.means[0, 1] + u0[1] - 0.1, + r"$\mathbf{U}_1$", + fontsize=15, +) +plt.text( + ubm.means[1, 0] + u1[0], + ubm.means[1, 1] + u1[1] - 0.1, + r"$\mathbf{U}_2$", + fontsize=15, +) +plt.text( + ubm.means[2, 0] + u2[0], + ubm.means[2, 1] + u2[1] - 0.1, + r"$\mathbf{U}_3$", + fontsize=15, +) + +# V +ax.arrow( + ubm.means[0, 0], + ubm.means[0, 1], + v0[0], + v0[1], + fc="k", + ec="k", + head_width=0.05, + head_length=0.1, +) +ax.arrow( + ubm.means[1, 0], + ubm.means[1, 1], + v1[0], + v1[1], + fc="k", + ec="k", + head_width=0.05, + head_length=0.1, +) +ax.arrow( + ubm.means[2, 0], + ubm.means[2, 1], + v2[0], + v2[1], + fc="k", + ec="k", + head_width=0.05, + head_length=0.1, +) +plt.text( + ubm.means[0, 0] + v0[0], + ubm.means[0, 1] + v0[1] - 0.1, + r"$\mathbf{V}_1$", + fontsize=15, +) +plt.text( + ubm.means[1, 0] + v1[0], + ubm.means[1, 1] + v1[1] - 0.1, + r"$\mathbf{V}_2$", + fontsize=15, +) +plt.text( + ubm.means[2, 0] + v2[0], + ubm.means[2, 1] + v2[1] - 0.1, + r"$\mathbf{V}_3$", + fontsize=15, +) + +plt.xticks([], []) +plt.yticks([], []) + +plt.xlabel("Sepal length") +plt.ylabel("Petal 
width") +plt.legend(loc=2) +plt.ylim([-1, 3.5]) + +plt.tight_layout() +plt.grid(True) +plt.show() diff --git a/doc/plot/plot_MAP.py b/doc/plot/plot_MAP.py new file mode 100644 index 0000000000000000000000000000000000000000..24e12ca400a20b8ca1db8427b141dd6549f71679 --- /dev/null +++ b/doc/plot/plot_MAP.py @@ -0,0 +1,59 @@ +import matplotlib.pyplot as plt +import numpy as np + +from sklearn.datasets import load_iris + +import bob.learn.em + +np.random.seed(10) + +iris_data = load_iris() +data = np.column_stack((iris_data.data[:, 0], iris_data.data[:, 3])) +setosa = data[iris_data.target == 0] +versicolor = data[iris_data.target == 1] +virginica = data[iris_data.target == 2] + + +# Two clusters with +mle_machine = bob.learn.em.GMMMachine(3) +# Creating some fake means for the example +mle_machine.means = np.array([[5, 3], [4, 2], [7, 3.0]]) +mle_machine.variances = np.array([[0.1, 0.5], [0.2, 0.2], [0.7, 0.5]]) + + +# Creating some random data centered in +map_machine = bob.learn.em.GMMMachine( + 3, trainer="map", ubm=mle_machine, map_relevance_factor=4 +).fit(data) + + +figure, ax = plt.subplots() +# plt.scatter(data[:, 0], data[:, 1], c="olivedrab", label="new data") +plt.scatter(setosa[:, 0], setosa[:, 1], c="darkcyan", label="setosa") +plt.scatter( + versicolor[:, 0], versicolor[:, 1], c="goldenrod", label="versicolor" +) +plt.scatter(virginica[:, 0], virginica[:, 1], c="dimgrey", label="virginica") +plt.scatter( + mle_machine.means[:, 0], + mle_machine.means[:, 1], + c="blue", + marker="x", + label="prior centroids - mle", + s=60, +) +plt.scatter( + map_machine.means[:, 0], + map_machine.means[:, 1], + c="red", + marker="^", + label="adapted centroids - map", + s=60, +) +plt.legend() +plt.xticks([], []) +plt.yticks([], []) +ax.set_xlabel("Sepal length") +ax.set_ylabel("Petal width") +plt.tight_layout() +plt.show() diff --git a/doc/plot/plot_ML.py b/doc/plot/plot_ML.py index 32f664338343883bb27540f255d3bac433ba971d..410abf0bea768b1c0306e73f0659b83703736f5d 100644 --- a/doc/plot/plot_ML.py +++ b/doc/plot/plot_ML.py @@ -1,7 +1,7 @@ import logging import matplotlib.pyplot as plt -import numpy +import numpy as np from matplotlib.lines import Line2D from matplotlib.patches import Ellipse @@ -13,7 +13,7 @@ logger = logging.getLogger("bob.learn.em") logger.setLevel("DEBUG") iris_data = load_iris() -data = iris_data.data +data = np.column_stack((iris_data.data[:, 0], iris_data.data[:, 3])) setosa = data[iris_data.target == 0] versicolor = data[iris_data.target == 1] virginica = data[iris_data.target == 2] @@ -28,7 +28,6 @@ machine = GMMMachine( ) # Initialize the means with known values (optional, skips kmeans) -machine.means = numpy.array([[5, 3], [4, 2], [7, 3]], dtype=float) machine = machine.fit(data) @@ -48,14 +47,33 @@ ax.scatter( s=60, ) + +def draw_ellipse(position, covariance, ax=None, **kwargs): + """ + Draw an ellipse with a given position and covariance + """ + ax = ax or plt.gca() + + # Convert covariance to principal axes + if covariance.shape == (2, 2): + U, s, Vt = np.linalg.svd(covariance) + angle = np.degrees(np.arctan2(U[1, 0], U[0, 0])) + width, height = 2 * np.sqrt(s) + else: + angle = 0 + width, height = 2 * np.sqrt(covariance) + + # Draw the Ellipse + for nsig in range(1, 4): + ax.add_patch( + Ellipse(position, nsig * width, nsig * height, angle, **kwargs) + ) + + # Draw ellipses for covariance -for mean, variance in zip(machine.means, machine.variances): - eigvals, eigvecs = numpy.linalg.eig(numpy.diag(variance)) - axis = numpy.sqrt(eigvals) * numpy.sqrt(5.991) - angle = 
180.0 * numpy.arctan(eigvecs[1][0] / eigvecs[1][1]) / numpy.pi - ax.add_patch( - Ellipse(mean, *axis, angle=angle, linewidth=1, fill=False, zorder=2) - ) +w_factor = 0.2 / np.max(machine.weights) +for w, pos, covar in zip(machine.weights, machine.means, machine.variances): + draw_ellipse(pos, covar, alpha=w * w_factor) # Plot details (legend, axis labels) plt.legend(