From eb4077e57cd174e359449f17ac406ed82509c3e1 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Wed, 20 Apr 2022 12:11:31 +0200 Subject: [PATCH] Refactor factor_analysis to accept n_classes --- bob/learn/em/factor_analysis.py | 368 +++++++++++++++++----- bob/learn/em/gmm.py | 12 +- bob/learn/em/kmeans.py | 28 +- bob/learn/em/test/test_factor_analysis.py | 14 +- bob/learn/em/utils.py | 33 ++ 5 files changed, 343 insertions(+), 112 deletions(-) create mode 100644 bob/learn/em/utils.py diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py index cb5680a..09d158c 100644 --- a/bob/learn/em/factor_analysis.py +++ b/bob/learn/em/factor_analysis.py @@ -2,21 +2,37 @@ # @author: Tiago de Freitas Pereira +import functools import logging +import operator import dask import numpy as np +from dask.delayed import Delayed from sklearn.base import BaseEstimator from sklearn.utils.multiclass import unique_labels from .gmm import GMMMachine -from .kmeans import check_and_persist_dask_input from .linear_scoring import linear_scoring +from .utils import array_to_delayed_list, check_and_persist_dask_input logger = logging.getLogger(__name__) +def is_input_delayed(X): + """ + Check if the input is a list of dask delayed objects. + """ + if isinstance(X, (list, tuple)): + return is_input_delayed(X[0]) + + if isinstance(X, Delayed): + return True + else: + return False + + def mult_along_axis(A, B, axis): """ Magic function to multiply two arrays along a given axis. @@ -106,7 +122,7 @@ class FactorAnalysisBase(BaseEstimator): self.relevance_factor = relevance_factor - if ubm is not None: + if ubm is not None and ubm._means is not None: self.create_UVD() @property @@ -208,20 +224,38 @@ class FactorAnalysisBase(BaseEstimator): if self.ubm is None: logger.info("FA: Creating a new GMMMachine and training it.") self.ubm = GMMMachine(**self.ubm_kwargs) - self.ubm.fit(X) # GMMMachine.fit takes non-labeled data + self.ubm.fit(X) + + if self.ubm._means is None: + self.ubm.fit(X) # Initializing the state matrix if not hasattr(self, "_U") or not hasattr(self, "_D"): self.create_UVD() - def initialize_using_stats(self, ubm_projected_X, y): + def initialize_using_stats(self, ubm_projected_X, y, n_classes): # Accumulating 0th and 1st order statistics # https://gitlab.idiap.ch/bob/bob.learn.em/-/blob/da92d0e5799d018f311f1bf5cdd5a80e19e142ca/bob/learn/em/cpp/ISVTrainer.cpp#L68 - # 0th order stats - n_acc = self._sum_n_statistics(ubm_projected_X, y) - # 1st order stats - f_acc = self._sum_f_statistics(ubm_projected_X, y) + if is_input_delayed(ubm_projected_X): + n_acc = [ + dask.delayed(self._sum_n_statistics)(xx, yy, n_classes) + for xx, yy in zip(ubm_projected_X, y) + ] + n_acc = dask.compute(*n_acc) + n_acc = functools.reduce(operator.iadd, n_acc) + + f_acc = [ + dask.delayed(self._sum_f_statistics)(xx, yy, n_classes) + for xx, yy in zip(ubm_projected_X, y) + ] + f_acc = dask.compute(*f_acc) + f_acc = functools.reduce(operator.iadd, f_acc) + else: + # 0th order stats + n_acc = self._sum_n_statistics(ubm_projected_X, y, n_classes) + # 1st order stats + f_acc = self._sum_f_statistics(ubm_projected_X, y, n_classes) return n_acc, f_acc @@ -258,7 +292,7 @@ class FactorAnalysisBase(BaseEstimator): else: self._V = 0 - def _sum_n_statistics(self, X, y): + def _sum_n_statistics(self, X, y, n_classes): """ Accumulates the 0th statistics for each client @@ -270,6 +304,9 @@ class FactorAnalysisBase(BaseEstimator): y: list of ints List of corresponding labels + n_classes: int + Number of 
classes + Returns ------- n_acc: array @@ -277,9 +314,7 @@ class FactorAnalysisBase(BaseEstimator): """ # 0th order stats - n_acc = np.zeros( - (self.estimate_number_of_classes(y), self.ubm.n_gaussians) - ) + n_acc = np.zeros((n_classes, self.ubm.n_gaussians)) # Iterate for each client for x_i, y_i in zip(X, y): @@ -288,7 +323,7 @@ class FactorAnalysisBase(BaseEstimator): return n_acc - def _sum_f_statistics(self, X, y): + def _sum_f_statistics(self, X, y, n_classes): """ Accumulates the 1st order statistics for each client @@ -299,6 +334,9 @@ class FactorAnalysisBase(BaseEstimator): y: list of ints List of corresponding labels + n_classes: int + Number of classes + Returns ------- f_acc: array @@ -309,7 +347,7 @@ class FactorAnalysisBase(BaseEstimator): # 1st order stats f_acc = np.zeros( ( - self.estimate_number_of_classes(y), + n_classes, self.ubm.n_gaussians, self.feature_dimension, ) @@ -410,7 +448,9 @@ class FactorAnalysisBase(BaseEstimator): return fn_x_ih - def update_x(self, X, y, UProd, latent_x, latent_y=None, latent_z=None): + def update_x( + self, X, y, n_classes, UProd, latent_x, latent_y=None, latent_z=None + ): """ Computes a new math:`E[x]` See equation (29) in [McCool2013]_ @@ -424,6 +464,9 @@ class FactorAnalysisBase(BaseEstimator): y: list of ints List of corresponding labels + n_classes: int + Number of classes + UProd: array Matrix containing U_c.T*inv(Sigma_c) @ U_c.T @@ -445,7 +488,7 @@ class FactorAnalysisBase(BaseEstimator): # U.T @ inv(Sigma) - See Eq(37) UTinvSigma = self._U.T / self.variance_supervector - session_offsets = np.zeros(self.estimate_number_of_classes(y)) + session_offsets = np.zeros(n_classes) # For each sample for x_i, y_i in zip(X, y): id_plus_prod_ih = self._compute_id_plus_u_prod_ih(x_i, UProd) @@ -789,7 +832,7 @@ class FactorAnalysisBase(BaseEstimator): return acc_D_A1, acc_D_A2 - def initialize_XYZ(self, y): + def initialize_XYZ(self, y, n_classes): """ Initialize E[x], E[y], E[z] state variables @@ -819,14 +862,12 @@ class FactorAnalysisBase(BaseEstimator): ) latent_y = ( - np.zeros((self.estimate_number_of_classes(y), self.r_V)) + np.zeros((n_classes, self.r_V)) if self.r_V and self.r_V > 0 else None ) - latent_z = np.zeros( - (self.estimate_number_of_classes(y), self.supervector_dimension) - ) + latent_z = np.zeros((n_classes, self.supervector_dimension)) return latent_x, latent_y, latent_z @@ -834,7 +875,9 @@ class FactorAnalysisBase(BaseEstimator): Estimating V and y """ - def update_y(self, X, y, VProd, latent_x, latent_y, latent_z, n_acc, f_acc): + def update_y( + self, X, y, n_classes, VProd, latent_x, latent_y, latent_z, n_acc, f_acc + ): """ Computes a new math:`E[y]` See equation (30) in [McCool2013]_ @@ -847,6 +890,9 @@ class FactorAnalysisBase(BaseEstimator): y: list of ints List of corresponding labels + n_classes: int + Number of classes + VProd: array Matrix representing V_c.T*inv(Sigma_c) @ V_c.T @@ -870,7 +916,7 @@ class FactorAnalysisBase(BaseEstimator): VTinvSigma = self._V.T / self.variance_supervector # Loops over the labels - for label in range(self.estimate_number_of_classes(y)): + for label in range(n_classes): id_plus_v_prod_i = self._compute_id_plus_vprod_i( n_acc[label], VProd ) @@ -1153,13 +1199,32 @@ class FactorAnalysisBase(BaseEstimator): def fit(self, X, y): input_is_dask, X = check_and_persist_dask_input(X) self.initialize(X) + y = np.squeeze(np.asarray(y)) if input_is_dask: - stats = [dask.delayed(self.ubm.transform)(xx) for xx in X] - stats = dask.compute(*stats) + chunks = X.chunks[0] + # chunk the y array 
similar to X into a list of numpy arrays + i, new_y = 0, [] + for chunk in chunks: + new_y.append(y[i : i + chunk]) + i += chunk + y = new_y + X = array_to_delayed_list(X, input_is_dask) + stats = [ + dask.delayed(self.ubm.transform_one_by_one)(xx).persist() + for xx in X + ] + # stats = dask.compute(*stats) + # stats = functools.reduce(operator.iadd, stats) else: - stats = [self.ubm.transform(xx) for xx in X] + stats = self.ubm.transform_one_by_one(X) del X # we don't need to persist X anymore - return self.fit_using_stats(stats, y) + + # if input_is_dask: + # dask.compute(dask.delayed(self.fit_using_stats)(stats, y)) + # else: + # self.fit_using_stats(stats, y) + self.fit_using_stats(stats, y) + return self class ISVMachine(FactorAnalysisBase): @@ -1213,18 +1278,30 @@ class ISVMachine(FactorAnalysisBase): **kwargs, ) - def e_step(self, X, y, n_acc, f_acc): + def e_step(self, X, y, n_classes, n_acc, f_acc): """ E-step of the EM algorithm """ # self.initialize_XYZ(y) UProd = self._compute_uprod() - latent_x, _, latent_z = self.initialize_XYZ(y) + latent_x, _, latent_z = self.initialize_XYZ(y, n_classes=n_classes) latent_y = None - latent_x = self.update_x(X, y, UProd, latent_x) + latent_x = self.update_x( + X=X, + y=y, + n_classes=n_classes, + UProd=UProd, + latent_x=latent_x, + ) latent_z = self.update_z( - X, y, latent_x, latent_y, latent_z, n_acc, f_acc + X=X, + y=y, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, + n_acc=n_acc, + f_acc=f_acc, ) acc_U_A1, acc_U_A2 = self.compute_accumulators_U( X, y, UProd, latent_x, latent_y, latent_z @@ -1232,7 +1309,7 @@ class ISVMachine(FactorAnalysisBase): return acc_U_A1, acc_U_A2 - def m_step(self, acc_U_A1, acc_U_A2): + def m_step(self, acc_U_A1_acc_U_A2_list): """ ISV M-step. This updates `U` matrix @@ -1247,6 +1324,11 @@ class ISVMachine(FactorAnalysisBase): Accumulated statistics for U_A2(n_gaussians* feature_dimension, r_U) """ + acc_U_A1 = [acc[0] for acc in acc_U_A1_acc_U_A2_list] + acc_U_A1 = functools.reduce(operator.iadd, acc_U_A1) + + acc_U_A2 = [acc[1] for acc in acc_U_A1_acc_U_A2_list] + acc_U_A2 = functools.reduce(operator.iadd, acc_U_A2) self.update_U(acc_U_A1, acc_U_A2) @@ -1269,14 +1351,36 @@ class ISVMachine(FactorAnalysisBase): """ y = np.asarray(y) + n_classes = self.estimate_number_of_classes(y) + n_acc, f_acc = self.initialize_using_stats(X, y, n_classes) - # TODO: Point of MAP-REDUCE - n_acc, f_acc = self.initialize_using_stats(X, y) for i in range(self.em_iterations): logger.info("U Training: Iteration %d", i + 1) # TODO: Point of MAP-REDUCE - acc_U_A1, acc_U_A2 = self.e_step(X, y, n_acc, f_acc) - self.m_step(acc_U_A1, acc_U_A2) + if is_input_delayed(X): + acc_U_A1_acc_U_A2_list = [ + dask.delayed(self.e_step)( + X=xx, + y=yy, + n_classes=n_classes, + n_acc=n_acc, + f_acc=f_acc, + ) + for xx, yy in zip(X, y) + ] + delayed_em_step = dask.delayed(self.m_step)( + acc_U_A1_acc_U_A2_list + ) + dask.compute(delayed_em_step) + else: + acc_U_A1, acc_U_A2 = self.e_step( + X=X, + y=y, + n_classes=n_classes, + n_acc=n_acc, + f_acc=f_acc, + ) + self.m_step([(acc_U_A1, acc_U_A2)]) return self @@ -1284,7 +1388,7 @@ class ISVMachine(FactorAnalysisBase): ubm_projected_X = self.ubm.transform(X) return self.estimate_ux(ubm_projected_X) - def enroll(self, X, iterations=1): + def enroll_using_stats(self, X, iterations=1): """ Enrolls a new client In ISV, the enrolment is defined as: :math:`m + Dz` with the latent variables `z` @@ -1295,6 +1399,7 @@ class ISVMachine(FactorAnalysisBase): X : list of 
:py:class:`bob.learn.em.GMMStats` List of statistics to be enrolled + iterations : int Number of iterations to perform @@ -1306,22 +1411,36 @@ class ISVMachine(FactorAnalysisBase): """ # We have only one class for enrollment y = list(np.zeros(len(X), dtype=np.int32)) - n_acc = self._sum_n_statistics(X, y=y) - f_acc = self._sum_f_statistics(X, y=y) + n_acc = self._sum_n_statistics(X, y=y, n_classes=1) + f_acc = self._sum_f_statistics(X, y=y, n_classes=1) UProd = self._compute_uprod() - latent_x, _, latent_z = self.initialize_XYZ(y) + latent_x, _, latent_z = self.initialize_XYZ(y, n_classes=1) latent_y = None for i in range(iterations): logger.info("Enrollment: Iteration %d", i + 1) - latent_x = self.update_x(X, y, UProd, latent_x, latent_y, latent_z) + latent_x = self.update_x( + X=X, + y=y, + n_classes=1, + UProd=UProd, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, + ) latent_z = self.update_z( - X, y, latent_x, latent_y, latent_z, n_acc, f_acc + X=X, + y=y, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, + n_acc=n_acc, + f_acc=f_acc, ) return latent_z - def enroll_with_array(self, X, iterations=1): + def enroll(self, X, iterations=1): """ Enrolls a new client using a numpy array as input @@ -1339,7 +1458,7 @@ class ISVMachine(FactorAnalysisBase): z """ - return self.enroll([self.ubm.transform(X)], iterations) + return self.enroll_using_stats([self.ubm.transform(X)], iterations) def score_using_stats(self, latent_z, data): """ @@ -1431,7 +1550,7 @@ class JFAMachine(FactorAnalysisBase): **kwargs, ) - def e_step_v(self, X, y, n_acc, f_acc): + def e_step_v(self, X, y, n_classes, n_acc, f_acc): """ ISV E-step for the V matrix. @@ -1444,6 +1563,9 @@ class JFAMachine(FactorAnalysisBase): y: list of int List of labels + n_classes: int + Number of classes + n_acc: array Accumulated 0th-order statistics @@ -1464,16 +1586,33 @@ class JFAMachine(FactorAnalysisBase): VProd = self._compute_vprod() - latent_x, latent_y, latent_z = self.initialize_XYZ(y) + latent_x, latent_y, latent_z = self.initialize_XYZ( + y, n_classes=n_classes + ) # UPDATE Y, X AND FINALLY Z latent_y = self.update_y( - X, y, VProd, latent_x, latent_y, latent_z, n_acc, f_acc + X=X, + y=y, + n_classes=n_classes, + VProd=VProd, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, + n_acc=n_acc, + f_acc=f_acc, ) acc_V_A1, acc_V_A2 = self.compute_accumulators_V( - X, y, VProd, n_acc, f_acc, latent_x, latent_y, latent_z + X=X, + y=y, + VProd=VProd, + n_acc=n_acc, + f_acc=f_acc, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, ) return acc_V_A1, acc_V_A2 @@ -1508,7 +1647,7 @@ class JFAMachine(FactorAnalysisBase): (self.ubm.n_gaussians * self.feature_dimension, self.r_V) ) - def finalize_v(self, X, y, n_acc, f_acc): + def finalize_v(self, X, y, n_classes, n_acc, f_acc): """ Compute for the last time `E[y]` @@ -1521,6 +1660,9 @@ class JFAMachine(FactorAnalysisBase): y: list of int List of labels + n_classes: int + Number of classes + n_acc: array Accumulated 0th-order statistics @@ -1535,16 +1677,26 @@ class JFAMachine(FactorAnalysisBase): """ VProd = self._compute_vprod() - latent_x, latent_y, latent_z = self.initialize_XYZ(y) + latent_x, latent_y, latent_z = self.initialize_XYZ( + y, n_classes=n_classes + ) # UPDATE Y, X AND FINALLY Z latent_y = self.update_y( - X, y, VProd, latent_x, latent_y, latent_z, n_acc, f_acc + X=X, + y=y, + n_classes=n_classes, + VProd=VProd, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, + n_acc=n_acc, + f_acc=f_acc, ) return latent_y - def 
e_step_u(self, X, y, latent_y): + def e_step_u(self, X, y, n_classes, latent_y): """ ISV E-step for the U matrix. @@ -1573,12 +1725,24 @@ class JFAMachine(FactorAnalysisBase): """ # self.initialize_XYZ(y) UProd = self._compute_uprod() - latent_x, _, latent_z = self.initialize_XYZ(y) + latent_x, _, latent_z = self.initialize_XYZ(y, n_classes) - latent_x = self.update_x(X, y, UProd, latent_x, latent_y) + latent_x = self.update_x( + X=X, + y=y, + n_classes=n_classes, + UProd=UProd, + latent_x=latent_x, + latent_y=latent_y, + ) acc_U_A1, acc_U_A2 = self.compute_accumulators_U( - X, y, UProd, latent_x, latent_y, latent_z + X=X, + y=y, + UProd=UProd, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, ) return acc_U_A1, acc_U_A2 @@ -1605,6 +1769,7 @@ class JFAMachine(FactorAnalysisBase): self, X, y, + n_classes, latent_y, ): """ @@ -1619,6 +1784,9 @@ class JFAMachine(FactorAnalysisBase): y: list of int List of labels + n_classes: int + Number of classes + latent_y: array E[y] latent variable @@ -1629,15 +1797,20 @@ class JFAMachine(FactorAnalysisBase): """ UProd = self._compute_uprod() - latent_x, _, _ = self.initialize_XYZ(y) + latent_x, _, _ = self.initialize_XYZ(y, n_classes=n_classes) latent_x = self.update_x( - X, y, UProd, latent_x=latent_x, latent_y=latent_y + X=X, + y=y, + n_classes=n_classes, + UProd=UProd, + latent_x=latent_x, + latent_y=latent_y, ) return latent_x - def e_step_d(self, X, y, latent_x, latent_y, n_acc, f_acc): + def e_step_d(self, X, y, n_classes, latent_x, latent_y, n_acc, f_acc): """ ISV E-step for the U matrix. @@ -1650,6 +1823,9 @@ class JFAMachine(FactorAnalysisBase): y: list of int List of labels + n_classes: int + Number of classes + latent_x: array E(x) latent variable @@ -1677,7 +1853,7 @@ class JFAMachine(FactorAnalysisBase): """ - _, _, latent_z = self.initialize_XYZ(y) + _, _, latent_z = self.initialize_XYZ(y, n_classes=n_classes) latent_z = self.update_z( X, @@ -1712,7 +1888,7 @@ class JFAMachine(FactorAnalysisBase): """ self._D = acc_D_A2 / acc_D_A1 - def enroll(self, X, iterations=1): + def enroll_using_stats(self, X, iterations=1): """ Enrolls a new client. 
In JFA the enrolment is defined as: :math:`m + Vy + Dz` with the latent variables `y` and `z` @@ -1734,27 +1910,49 @@ class JFAMachine(FactorAnalysisBase): """ # We have only one class for enrollment y = list(np.zeros(len(X), dtype=np.int32)) - n_acc = self._sum_n_statistics(X, y=y) - f_acc = self._sum_f_statistics(X, y=y) + n_acc = self._sum_n_statistics(X, y=y, n_classes=1) + f_acc = self._sum_f_statistics(X, y=y, n_classes=1) UProd = self._compute_uprod() VProd = self._compute_vprod() - latent_x, latent_y, latent_z = self.initialize_XYZ(y) + latent_x, latent_y, latent_z = self.initialize_XYZ(y, n_classes=1) for i in range(iterations): logger.info("Enrollment: Iteration %d", i + 1) latent_y = self.update_y( - X, y, VProd, latent_x, latent_y, latent_z, n_acc, f_acc + X=X, + y=y, + n_classes=1, + VProd=VProd, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, + n_acc=n_acc, + f_acc=f_acc, + ) + latent_x = self.update_x( + X=X, + y=y, + n_classes=1, + UProd=UProd, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, ) - latent_x = self.update_x(X, y, UProd, latent_x, latent_y, latent_z) latent_z = self.update_z( - X, y, latent_x, latent_y, latent_z, n_acc, f_acc + X=X, + y=y, + latent_x=latent_x, + latent_y=latent_y, + latent_z=latent_z, + n_acc=n_acc, + f_acc=f_acc, ) # The latent variables are wrapped in to 2axis arrays return latent_y[0], latent_z[0] - def enroll_with_array(self, X, iterations=1): + def enroll(self, X, iterations=1): """ Enrolls a new client using a numpy array as input @@ -1772,7 +1970,7 @@ class JFAMachine(FactorAnalysisBase): z """ - return self.enroll([self.ubm.transform(X)], iterations) + return self.enroll_using_stats([self.ubm.transform(X)], iterations) def fit_using_stats(self, X, y): """ @@ -1801,33 +1999,55 @@ class JFAMachine(FactorAnalysisBase): self.create_UVD() y = np.asarray(y) + n_classes = self.estimate_number_of_classes(y) # TODO: Point of MAP-REDUCE - n_acc, f_acc = self.initialize_using_stats(X, y) + n_acc, f_acc = self.initialize_using_stats(X, y, n_classes=n_classes) # Updating V for i in range(self.em_iterations): logger.info("V Training: Iteration %d", i + 1) # TODO: Point of MAP-REDUCE - acc_V_A1, acc_V_A2 = self.e_step_v(X, y, n_acc, f_acc) + acc_V_A1, acc_V_A2 = self.e_step_v( + X=X, + y=y, + n_classes=n_classes, + n_acc=n_acc, + f_acc=f_acc, + ) self.m_step_v(acc_V_A1, acc_V_A2) - latent_y = self.finalize_v(X, y, n_acc, f_acc) + latent_y = self.finalize_v( + X=X, y=y, n_classes=n_classes, n_acc=n_acc, f_acc=f_acc + ) # Updating U for i in range(self.em_iterations): logger.info("U Training: Iteration %d", i + 1) # TODO: Point of MAP-REDUCE - acc_U_A1, acc_U_A2 = self.e_step_u(X, y, latent_y) + acc_U_A1, acc_U_A2 = self.e_step_u( + X=X, + y=y, + n_classes=n_classes, + latent_y=latent_y, + ) self.m_step_u(acc_U_A1, acc_U_A2) - latent_x = self.finalize_u(X, y, latent_y) + latent_x = self.finalize_u( + X=X, y=y, n_classes=n_classes, latent_y=latent_y + ) # Updating D for i in range(self.em_iterations): logger.info("D Training: Iteration %d", i + 1) # TODO: Point of MAP-REDUCE acc_D_A1, acc_D_A2 = self.e_step_d( - X, y, latent_x, latent_y, n_acc, f_acc + X=X, + y=y, + n_classes=n_classes, + latent_x=latent_x, + latent_y=latent_y, + n_acc=n_acc, + f_acc=f_acc, ) self.m_step_d(acc_D_A1, acc_D_A2) diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py index ce8581a..394a930 100644 --- a/bob/learn/em/gmm.py +++ b/bob/learn/em/gmm.py @@ -18,11 +18,8 @@ import numpy as np from h5py import File as HDF5File from sklearn.base import 
BaseEstimator -from .kmeans import ( - KMeansMachine, - array_to_delayed_list, - check_and_persist_dask_input, -) +from .kmeans import KMeansMachine +from .utils import array_to_delayed_list, check_and_persist_dask_input logger = logging.getLogger(__name__) @@ -853,10 +850,13 @@ class GMMMachine(BaseEstimator): ) return self - def transform(self, X, **kwargs): + def transform(self, X): """Returns the statistics for `X`.""" return e_step(data=X, machine=self) + def transform_one_by_one(self, X): + return [e_step(data=xx, machine=self) for xx in X] + def _more_tags(self): return { "stateless": False, diff --git a/bob/learn/em/kmeans.py b/bob/learn/em/kmeans.py index 8a17af3..61af76e 100644 --- a/bob/learn/em/kmeans.py +++ b/bob/learn/em/kmeans.py @@ -15,6 +15,8 @@ import scipy.spatial.distance from dask_ml.cluster.k_means import k_init from sklearn.base import BaseEstimator +from .utils import array_to_delayed_list, check_and_persist_dask_input + logger = logging.getLogger(__name__) @@ -171,32 +173,6 @@ def reduce_indices_means_vars(stats): return variances, weights -def check_and_persist_dask_input(data): - # check if input is a dask array. If so, persist and rebalance data - input_is_dask = False - if isinstance(data, da.Array): - data: da.Array = data.persist() - input_is_dask = True - # if there is a dask distributed client, rebalance data - try: - client = dask.distributed.Client.current() - client.rebalance() - except ValueError: - pass - - else: - data = np.asarray(data) - return input_is_dask, data - - -def array_to_delayed_list(data, input_is_dask): - # If input is a dask array, convert to delayed chunks - if input_is_dask: - data = data.to_delayed().ravel().tolist() - logger.debug(f"Got {len(data)} chunks.") - return data - - class KMeansMachine(BaseEstimator): """Stores the k-means clusters parameters (centroid of each cluster). 
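[Editor's note, before the test changes: the core pattern this patch introduces is worth spelling out in isolation. When the input is a dask array, each chunk gets its own delayed e-step, and a single delayed m-step reduces the per-chunk accumulators with functools.reduce(operator.iadd, ...) before updating the model. Below is a minimal, self-contained sketch of that map-reduce EM loop; toy_e_step and toy_m_step are illustrative stand-ins, not functions from bob.learn.em.]

import functools
import operator

import dask
import numpy as np


def toy_e_step(X_chunk):
    # Per-chunk accumulators: a 0th-order count and a 1st-order sum,
    # mirroring the n_acc/f_acc statistics in FactorAnalysisBase.
    return np.array([len(X_chunk)], dtype=float), X_chunk.sum(axis=0)


def toy_m_step(acc_list):
    # Reduce the per-chunk accumulators, as ISVMachine.m_step now does,
    # then derive the updated parameter (here simply a global mean).
    n_acc = functools.reduce(operator.iadd, [a[0] for a in acc_list])
    f_acc = functools.reduce(operator.iadd, [a[1] for a in acc_list])
    return f_acc / n_acc


chunks = [np.random.rand(50, 3) for _ in range(4)]
delayed_accs = [dask.delayed(toy_e_step)(c) for c in chunks]
# One graph per EM iteration: all e-steps run in parallel, one reduce.
(mean,) = dask.compute(dask.delayed(toy_m_step)(delayed_accs))

[With a distributed scheduler, only the m-step result travels back to the client; the per-chunk accumulators stay on the workers, which is what lets the e-step scale with the number of chunks.]
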
diff --git a/bob/learn/em/test/test_factor_analysis.py b/bob/learn/em/test/test_factor_analysis.py index 3b9ca21..08ae5ab 100644 --- a/bob/learn/em/test/test_factor_analysis.py +++ b/bob/learn/em/test/test_factor_analysis.py @@ -194,7 +194,7 @@ def test_JFATrainAndEnrol(): gse2.sum_px = Fe[:, 1].reshape(2, 3) gse = [gse1, gse2] - latent_y, latent_z = it.enroll(gse, 5) + latent_y, latent_z = it.enroll_using_stats(gse, 5) y_ref = np.array([0.555991469319657, 0.002773650670010], "float64") z_ref = np.array( @@ -302,7 +302,7 @@ def test_ISVTrainAndEnrol(): gse2.sum_px = Fe[:, 1].reshape(2, 3) gse = [gse1, gse2] - latent_z = it.enroll(gse, 5) + latent_z = it.enroll_using_stats(gse, 5) np.testing.assert_allclose(latent_z, z_ref, rtol=eps, atol=1e-8) @@ -320,13 +320,14 @@ def test_JFATrainInitialize(): it = JFAMachine(2, 2, em_iterations=10, ubm=ubm) # first round - it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y) + n_classes = it.estimate_number_of_classes(TRAINING_STATS_y) + it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y, n_classes) u1 = it.U v1 = it.V d1 = it.D # second round - it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y) + it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y, n_classes) u2 = it.U v2 = it.V d2 = it.D @@ -350,12 +351,13 @@ def test_ISVTrainInitialize(): it = ISVMachine(2, em_iterations=10, ubm=ubm) # it.rng = rng - it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y) + n_classes = it.estimate_number_of_classes(TRAINING_STATS_y) + it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y, n_classes) u1 = copy.deepcopy(it.U) d1 = copy.deepcopy(it.D) # second round - it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y) + it.initialize_using_stats(TRAINING_STATS_X, TRAINING_STATS_y, n_classes) u2 = it.U d2 = it.D diff --git a/bob/learn/em/utils.py b/bob/learn/em/utils.py new file mode 100644 index 0000000..108e797 --- /dev/null +++ b/bob/learn/em/utils.py @@ -0,0 +1,33 @@ +import logging + +import dask +import dask.array as da +import numpy as np + +logger = logging.getLogger(__name__) + + +def check_and_persist_dask_input(data): + # check if input is a dask array. If so, persist and rebalance data + input_is_dask = False + if isinstance(data, da.Array): + data: da.Array = data.persist() + input_is_dask = True + # if there is a dask distributed client, rebalance data + try: + client = dask.distributed.Client.current() + client.rebalance() + except ValueError: + pass + + else: + data = np.asarray(data) + return input_is_dask, data + + +def array_to_delayed_list(data, input_is_dask): + # If input is a dask array, convert to delayed chunks + if input_is_dask: + data = data.to_delayed().ravel().tolist() + logger.debug(f"Got {len(data)} chunks.") + return data -- GitLab
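
[Editor's note: for reference, a hedged end-to-end usage sketch of the refactored API. The import path and constructor arguments follow what the tests above and FactorAnalysisBase.initialize suggest (r_U as the first positional argument, ubm_kwargs forwarded to GMMMachine when no ubm is given); treat these as assumptions rather than a documented contract.]

import dask.array as da
import numpy as np

from bob.learn.em import ISVMachine  # assumed public import path

# Dask-array input: check_and_persist_dask_input persists it, fit()
# slices y to match X.chunks, and array_to_delayed_list turns each
# chunk into one delayed e-step task.
X = da.random.random((200, 3), chunks=(50, 3))
y = np.repeat(np.arange(4), 50)  # four classes, aligned with the chunks

isv = ISVMachine(2, em_iterations=10, ubm_kwargs=dict(n_gaussians=2))
isv.fit(X, y)  # trains the UBM first, then U and D via the delayed EM loop

# Enrollment after this patch: enroll() takes raw features and projects
# them through the UBM itself, while enroll_using_stats() keeps the old
# GMMStats-based input.
probe = np.random.random((10, 3))
latent_z = isv.enroll(probe, iterations=1)
same_z = isv.enroll_using_stats([isv.ubm.transform(probe)], iterations=1)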