From 5f49f8e225bf495613587ad0a87e7ab7648d4065 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Wed, 20 Apr 2022 18:59:38 +0200
Subject: [PATCH] [factor_analysis] somehow it works with dask all of a sudden

---
 bob/learn/em/factor_analysis.py           | 450 ++++++++++++++--------
 bob/learn/em/gmm.py                       |   2 +-
 bob/learn/em/test/test_factor_analysis.py |   7 +-
 bob/learn/em/utils.py                     |   5 +-
 4 files changed, 305 insertions(+), 159 deletions(-)

diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index 09d158c..6dd609e 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -9,30 +9,52 @@ import operator
 
 import dask
 import numpy as np
+from dask.array.core import Array
 from dask.delayed import Delayed
 from sklearn.base import BaseEstimator
 from sklearn.utils.multiclass import unique_labels
 
 from .gmm import GMMMachine
 from .linear_scoring import linear_scoring
-from .utils import array_to_delayed_list, check_and_persist_dask_input
+from .utils import check_and_persist_dask_input
 
 logger = logging.getLogger(__name__)
 
 
-def is_input_delayed(X):
+def is_input_dask_nested(X):
     """
-    Check if the input is a list of dask delayed objects.
+    Check if the input is a dask Delayed or Array object, or a (nested)
+    list of such objects.
     """
     if isinstance(X, (list, tuple)):
-        return is_input_delayed(X[0])
+        return is_input_dask_nested(X[0])
 
-    if isinstance(X, Delayed):
+    if isinstance(X, (Delayed, Array)):
         return True
     else:
         return False
 
 
+def check_dask_input_samples_per_class(X, y):
+    input_is_dask = is_input_dask_nested(X)
+
+    if input_is_dask:
+        n_classes = len(y)
+        n_samples_per_class = [len(yy) for yy in y]
+    else:
+        unique_labels_y = unique_labels(y)
+        n_classes = len(unique_labels_y)
+        n_samples_per_class = [
+            sum(y == class_id) for class_id in unique_labels_y
+        ]
+    return input_is_dask, n_classes, n_samples_per_class
+
+
+def reduce_iadd(a):
+    """Reduces a list by in-place adding (iadd) all elements into the first element."""
+    return functools.reduce(operator.iadd, a)
+
+
 def mult_along_axis(A, B, axis):
     """
     Magic function to multiply two arrays along a given axis.
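[Review note, not part of the patch: the helpers added above encode the input
convention used throughout this change. With dask inputs, X and y arrive
pre-split into one collection per class, so y is a list of per-class label
sequences; with plain numpy inputs, y is a flat label array. A minimal,
self-contained sketch of both call shapes, with toy values invented for
illustration:

    import dask
    import numpy as np

    # numpy path: X is a plain array, y a flat label vector
    X = np.zeros((5, 3))
    y = np.array([0, 0, 1, 1, 1])
    check_dask_input_samples_per_class(X, y)  # -> (False, 2, [2, 3])

    # dask path: X and y are lists with one entry per class
    X_nested = [dask.delayed(np.zeros)((2, 3)), dask.delayed(np.zeros)((3, 3))]
    y_nested = [[0, 0], [1, 1, 1]]
    check_dask_input_samples_per_class(X_nested, y_nested)  # -> (True, 2, [2, 3])

Also note that reduce_iadd folds with operator.iadd, so the first element of
the input list is modified in place; callers should not reuse it afterwards.]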
@@ -237,20 +259,19 @@ class FactorAnalysisBase(BaseEstimator):
         # Accumulating 0th and 1st order statistics
         # https://gitlab.idiap.ch/bob/bob.learn.em/-/blob/da92d0e5799d018f311f1bf5cdd5a80e19e142ca/bob/learn/em/cpp/ISVTrainer.cpp#L68
 
-        if is_input_delayed(ubm_projected_X):
+        if is_input_dask_nested(ubm_projected_X):
             n_acc = [
                 dask.delayed(self._sum_n_statistics)(xx, yy, n_classes)
                 for xx, yy in zip(ubm_projected_X, y)
             ]
-            n_acc = dask.compute(*n_acc)
-            n_acc = functools.reduce(operator.iadd, n_acc)
 
             f_acc = [
                 dask.delayed(self._sum_f_statistics)(xx, yy, n_classes)
                 for xx, yy in zip(ubm_projected_X, y)
             ]
-            f_acc = dask.compute(*f_acc)
-            f_acc = functools.reduce(operator.iadd, f_acc)
+            n_acc, f_acc = dask.compute(n_acc, f_acc)
+            n_acc = reduce_iadd(n_acc)
+            f_acc = reduce_iadd(f_acc)
         else:
             # 0th order stats
             n_acc = self._sum_n_statistics(ubm_projected_X, y, n_classes)
@@ -374,7 +395,7 @@ class FactorAnalysisBase(BaseEstimator):
             Class id to return the statistics for
         """
         X = np.array(X)
-        return list(X[np.where(np.array(y) == i)[0]])
+        return list(X[np.array(y) == i])
 
     """
     Estimating U and x
@@ -406,7 +427,7 @@ class FactorAnalysisBase(BaseEstimator):
         # TODO: make the invertion matrix function as a parameter
         return np.linalg.inv(I + (UProd * n_i[:, None, None]).sum(axis=0))
 
-    def _computefn_x_ih(self, x_i, latent_z_i=None, latent_y_i=None):
+    def _compute_fn_x_ih(self, x_i, latent_z_i=None, latent_y_i=None):
         """
         Computes Fn_x_ih = N_{i,h}*(o_{i,h} - m - D*z_{i} - V*y_{i})
         Check equation (29) in [McCool2013]_
@@ -448,8 +469,8 @@ class FactorAnalysisBase(BaseEstimator):
 
         return fn_x_ih
 
-    def update_x(
-        self, X, y, n_classes, UProd, latent_x, latent_y=None, latent_z=None
+    def compute_latent_x(
+        self, *, X, y, n_classes, UProd, latent_y=None, latent_z=None
     ):
         """
         Computes a new math:`E[x]` See equation (29) in [McCool2013]_
@@ -464,15 +485,9 @@ class FactorAnalysisBase(BaseEstimator):
         y: list of ints
             List of corresponding labels
 
-        n_classes: int
-            Number of classes
-
         UProd: array
             Matrix containing U_c.T*inv(Sigma_c) @ U_c.T
 
-        latent_x: array
-            E(x) latent variable
-
         latent_y: array
             E(y) latent variable
 
@@ -482,27 +497,50 @@ class FactorAnalysisBase(BaseEstimator):
         Returns
         -------
             Returns the new latent_x
-
         """
 
         # U.T @ inv(Sigma) - See Eq(37)
         UTinvSigma = self._U.T / self.variance_supervector
 
-        session_offsets = np.zeros(n_classes)
+        if is_input_dask_nested(X):
+            latent_x = [
+                dask.delayed(self._compute_latent_x_per_class)(
+                    X_i=X_i,
+                    UProd=UProd,
+                    UTinvSigma=UTinvSigma,
+                    latent_y_i=latent_y[y_i] if latent_y is not None else None,
+                    latent_z_i=latent_z[y_i] if latent_z is not None else None,
+                )
+                for y_i, X_i in enumerate(X)
+            ]
+            latent_x = dask.compute(*latent_x)
+        else:
+            latent_x = [None] * n_classes
+            for y_i in unique_labels(y):
+                latent_x[y_i] = self._compute_latent_x_per_class(
+                    X_i=np.array(X)[y == y_i],
+                    UProd=UProd,
+                    UTinvSigma=UTinvSigma,
+                    latent_y_i=latent_y[y_i] if latent_y is not None else None,
+                    latent_z_i=latent_z[y_i] if latent_z is not None else None,
+                )
+
+        return latent_x
+
+    def _compute_latent_x_per_class(
+        self, *, X_i, UProd, UTinvSigma, latent_y_i, latent_z_i
+    ):
         # For each sample
-        for x_i, y_i in zip(X, y):
+        latent_x_i = []
+        for x_i in X_i:
             id_plus_prod_ih = self._compute_id_plus_u_prod_ih(x_i, UProd)
-            latent_z_i = latent_z[y_i] if latent_z is not None else None
-            latent_y_i = latent_y[y_i] if latent_y is not None else None
-            fn_x_ih = self._computefn_x_ih(
+            fn_x_ih = self._compute_fn_x_ih(
                 x_i, latent_z_i=latent_z_i, latent_y_i=latent_y_i
             )
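+            # Eq. (29): E[x_{i,h}] = (I + U^T Sigma^-1 N_{i,h} U)^-1 U^T Sigma^-1 Fn_x_ih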
-            latent_x[y_i][:, int(session_offsets[y_i])] = id_plus_prod_ih @ (
-                UTinvSigma @ fn_x_ih
-            )
-            session_offsets[y_i] += 1
-        return latent_x
+            latent_x_i.append(id_plus_prod_ih @ (UTinvSigma @ fn_x_ih))
+        latent_x_i = np.swapaxes(latent_x_i, 0, 1)
+        return latent_x_i
 
     def update_U(self, acc_U_A1, acc_U_A2):
         """
@@ -612,7 +650,7 @@ class FactorAnalysisBase(BaseEstimator):
             id_plus_prod_ih = self._compute_id_plus_u_prod_ih(x_i, UProd)
             latent_z_i = latent_z[y_i] if latent_z is not None else None
             latent_y_i = latent_y[y_i] if latent_y is not None else None
-            fn_x_ih = self._computefn_x_ih(
+            fn_x_ih = self._compute_fn_x_ih(
                 x_i, latent_y_i=latent_y_i, latent_z_i=latent_z_i
             )
@@ -832,7 +870,7 @@ class FactorAnalysisBase(BaseEstimator):
 
         return acc_D_A1, acc_D_A2
 
-    def initialize_XYZ(self, y, n_classes):
+    def initialize_XYZ(self, n_samples_per_class):
         """
         Initialize E[x], E[y], E[z] state variables
 
@@ -841,7 +879,7 @@ class FactorAnalysisBase(BaseEstimator):
             Eq. (37)
 
-        latent_y =
+        latent_y = (n_classes, r_V) or None
             Eq. (36)
 
         latent_x = (n_classes, r_U, n_sessions)
@@ -851,16 +889,10 @@ class FactorAnalysisBase(BaseEstimator):
         # x (Eq. 36)
         # (n_classes, r_U, n_samples )
         latent_x = []
-        for y_i in set(y):
-            latent_x.append(
-                np.zeros(
-                    (
-                        self.r_U,
-                        np.sum(y == y_i),
-                    )
-                )
-            )
+        for n_s in n_samples_per_class:
+            latent_x.append(np.zeros((self.r_U, n_s)))
 
+        n_classes = len(n_samples_per_class)
         latent_y = (
             np.zeros((n_classes, self.r_V))
             if self.r_V and self.r_V > 0
@@ -876,7 +908,17 @@ class FactorAnalysisBase(BaseEstimator):
     """
 
     def update_y(
-        self, X, y, n_classes, VProd, latent_x, latent_y, latent_z, n_acc, f_acc
+        self,
+        *,
+        X,
+        y,
+        n_classes,
+        VProd,
+        latent_x,
+        latent_y,
+        latent_z,
+        n_acc,
+        f_acc,
     ):
         """
         Computes a new math:`E[y]` See equation (30) in [McCool2013]_
@@ -915,20 +957,55 @@ class FactorAnalysisBase(BaseEstimator):
         # V.T / sigma
         VTinvSigma = self._V.T / self.variance_supervector
 
-        # Loops over the labels
-        for label in range(n_classes):
-            id_plus_v_prod_i = self._compute_id_plus_vprod_i(
-                n_acc[label], VProd
-            )
-            X_i = self._get_statistics_by_class_id(X, y, label)
-            fn_y_i = self._compute_fn_y_i(
-                X_i,
-                latent_x[label],
-                latent_z[label],
-                n_acc[label],
-                f_acc[label],
-            )
-            latent_y[label] = (VTinvSigma @ fn_y_i) @ id_plus_v_prod_i
+        if is_input_dask_nested(X):
+            latent_y = [
+                dask.delayed(self._latent_y_per_class)(
+                    X_i=X_i,
+                    n_acc_i=n_acc[label],
+                    f_acc_i=f_acc[label],
+                    VProd=VProd,
+                    VTinvSigma=VTinvSigma,
+                    latent_x_i=latent_x[label],
+                    latent_z_i=latent_z[label],
+                )
+                for label, X_i in enumerate(X)
+            ]
+            latent_y = dask.compute(*latent_y)
+        else:
+            # Loops over the labels
+            for label in range(n_classes):
+                X_i = self._get_statistics_by_class_id(X, y, label)
+                latent_y[label] = self._latent_y_per_class(
+                    X_i=X_i,
+                    n_acc_i=n_acc[label],
+                    f_acc_i=f_acc[label],
+                    VProd=VProd,
+                    VTinvSigma=VTinvSigma,
+                    latent_x_i=latent_x[label],
+                    latent_z_i=latent_z[label],
+                )
+        return latent_y
+
+    def _latent_y_per_class(
+        self,
+        *,
+        X_i,
+        n_acc_i,
+        f_acc_i,
+        VProd,
+        VTinvSigma,
+        latent_x_i,
+        latent_z_i,
+    ):
+        id_plus_v_prod_i = self._compute_id_plus_vprod_i(n_acc_i, VProd)
+        fn_y_i = self._compute_fn_y_i(
+            X_i,
+            latent_x_i,
+            latent_z_i,
+            n_acc_i,
+            f_acc_i,
+        )
+        latent_y = (VTinvSigma @ fn_y_i) @ id_plus_v_prod_i
         return latent_y
 
     def _compute_id_plus_vprod_i(self, n_acc_i, VProd):
@@ -1197,32 +1274,27 @@ class FactorAnalysisBase(BaseEstimator):
         return self.score_using_stats(model, self.ubm.transform(data))
 
     def fit(self, X, y):
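+        # X and y may be plain numpy arrays (flat) or dask collections;
+        # dask inputs are re-split into one chunk per class below.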
-        input_is_dask, X = check_and_persist_dask_input(X)
-        self.initialize(X)
+        input_is_dask, X = check_and_persist_dask_input(X, persist=False)
         y = np.squeeze(np.asarray(y))
+
+        self.initialize(X)
+
         if input_is_dask:
-            chunks = X.chunks[0]
-            # chunk the y array similar to X into a list of numpy arrays
-            i, new_y = 0, []
-            for chunk in chunks:
-                new_y.append(y[i : i + chunk])
-                i += chunk
-            y = new_y
-            X = array_to_delayed_list(X, input_is_dask)
+            # split the X array based on the classes
+            X_new, y_new = [], []
+            for class_id in unique_labels(y):
+                class_indices = y == class_id
+                X_new.append(X[class_indices])
+                y_new.append(y[class_indices])
+            X, y = X_new, y_new
+
             stats = [
-                dask.delayed(self.ubm.transform_one_by_one)(xx).persist()
+                dask.delayed(self.ubm.stats_per_sample)(xx).persist()
                 for xx in X
             ]
-            # stats = dask.compute(*stats)
-            # stats = functools.reduce(operator.iadd, stats)
         else:
-            stats = self.ubm.transform_one_by_one(X)
-            del X  # we don't need to persist X anymore
+            stats = self.ubm.stats_per_sample(X)
 
-        # if input_is_dask:
-        #     dask.compute(dask.delayed(self.fit_using_stats)(stats, y))
-        # else:
-        #     self.fit_using_stats(stats, y)
         self.fit_using_stats(stats, y)
         return self
@@ -1278,21 +1350,22 @@ class ISVMachine(FactorAnalysisBase):
             **kwargs,
         )
 
-    def e_step(self, X, y, n_classes, n_acc, f_acc):
+    def e_step(self, X, y, n_samples_per_class, n_acc, f_acc):
         """
         E-step of the EM algorithm
         """
         # self.initialize_XYZ(y)
         UProd = self._compute_uprod()
-        latent_x, _, latent_z = self.initialize_XYZ(y, n_classes=n_classes)
+        _, _, latent_z = self.initialize_XYZ(
+            n_samples_per_class=n_samples_per_class
+        )
         latent_y = None
-        latent_x = self.update_x(
+        latent_x = self.compute_latent_x(
             X=X,
             y=y,
-            n_classes=n_classes,
+            n_classes=len(n_samples_per_class),
             UProd=UProd,
-            latent_x=latent_x,
         )
         latent_z = self.update_z(
             X=X,
@@ -1325,10 +1398,10 @@ class ISVMachine(FactorAnalysisBase):
         """
         acc_U_A1 = [acc[0] for acc in acc_U_A1_acc_U_A2_list]
-        acc_U_A1 = functools.reduce(operator.iadd, acc_U_A1)
+        acc_U_A1 = reduce_iadd(acc_U_A1)
 
         acc_U_A2 = [acc[1] for acc in acc_U_A1_acc_U_A2_list]
-        acc_U_A2 = functools.reduce(operator.iadd, acc_U_A2)
+        acc_U_A2 = reduce_iadd(acc_U_A2)
 
         self.update_U(acc_U_A1, acc_U_A2)
@@ -1349,38 +1422,38 @@ class ISVMachine(FactorAnalysisBase):
             Returns self.
""" + ( + input_is_dask, + n_classes, + n_samples_per_class, + ) = check_dask_input_samples_per_class(X, y) - y = np.asarray(y) - n_classes = self.estimate_number_of_classes(y) n_acc, f_acc = self.initialize_using_stats(X, y, n_classes) for i in range(self.em_iterations): logger.info("U Training: Iteration %d", i + 1) - # TODO: Point of MAP-REDUCE - if is_input_delayed(X): - acc_U_A1_acc_U_A2_list = [ + if input_is_dask: + e_step_output = [ dask.delayed(self.e_step)( X=xx, y=yy, - n_classes=n_classes, + n_samples_per_class=n_samples_per_class, n_acc=n_acc, f_acc=f_acc, ) for xx, yy in zip(X, y) ] - delayed_em_step = dask.delayed(self.m_step)( - acc_U_A1_acc_U_A2_list - ) + delayed_em_step = dask.delayed(self.m_step)(e_step_output) dask.compute(delayed_em_step) else: - acc_U_A1, acc_U_A2 = self.e_step( + e_step_output = self.e_step( X=X, y=y, - n_classes=n_classes, + n_samples_per_class=n_samples_per_class, n_acc=n_acc, f_acc=f_acc, ) - self.m_step([(acc_U_A1, acc_U_A2)]) + self.m_step([e_step_output]) return self @@ -1415,16 +1488,15 @@ class ISVMachine(FactorAnalysisBase): f_acc = self._sum_f_statistics(X, y=y, n_classes=1) UProd = self._compute_uprod() - latent_x, _, latent_z = self.initialize_XYZ(y, n_classes=1) + _, _, latent_z = self.initialize_XYZ(n_samples_per_class=[len(X)]) latent_y = None for i in range(iterations): logger.info("Enrollment: Iteration %d", i + 1) - latent_x = self.update_x( + latent_x = self.compute_latent_x( X=X, y=y, n_classes=1, UProd=UProd, - latent_x=latent_x, latent_y=latent_y, latent_z=latent_z, ) @@ -1550,7 +1622,7 @@ class JFAMachine(FactorAnalysisBase): **kwargs, ) - def e_step_v(self, X, y, n_classes, n_acc, f_acc): + def e_step_v(self, X, y, n_samples_per_class, n_acc, f_acc): """ ISV E-step for the V matrix. @@ -1587,11 +1659,12 @@ class JFAMachine(FactorAnalysisBase): VProd = self._compute_vprod() latent_x, latent_y, latent_z = self.initialize_XYZ( - y, n_classes=n_classes + n_samples_per_class=n_samples_per_class ) # UPDATE Y, X AND FINALLY Z + n_classes = len(n_samples_per_class) latent_y = self.update_y( X=X, y=y, @@ -1617,7 +1690,7 @@ class JFAMachine(FactorAnalysisBase): return acc_V_A1, acc_V_A2 - def m_step_v(self, acc_V_A1, acc_V_A2): + def m_step_v(self, acc_V_A1_acc_V_A2_list): """ `V` Matrix M-step. 
         This updates the `V` matrix
@@ -1632,6 +1705,11 @@ class JFAMachine(FactorAnalysisBase):
             Accumulated statistics for V_A2(n_gaussians* feature_dimension, r_V)
 
         """
+        acc_V_A1 = [acc[0] for acc in acc_V_A1_acc_V_A2_list]
+        acc_V_A1 = reduce_iadd(acc_V_A1)
+
+        acc_V_A2 = [acc[1] for acc in acc_V_A1_acc_V_A2_list]
+        acc_V_A2 = reduce_iadd(acc_V_A2)
 
         # Inverting A1 over the zero axis
         # https://stackoverflow.com/questions/11972102/is-there-a-way-to-efficiently-invert-an-array-of-matrices-with-numpy
@@ -1647,7 +1725,7 @@ class JFAMachine(FactorAnalysisBase):
             (self.ubm.n_gaussians * self.feature_dimension, self.r_V)
         )
 
-    def finalize_v(self, X, y, n_classes, n_acc, f_acc):
+    def finalize_v(self, X, y, n_samples_per_class, n_acc, f_acc):
         """
         Compute for the last time `E[y]`
 
@@ -1678,11 +1756,12 @@ class JFAMachine(FactorAnalysisBase):
 
         VProd = self._compute_vprod()
         latent_x, latent_y, latent_z = self.initialize_XYZ(
-            y, n_classes=n_classes
+            n_samples_per_class=n_samples_per_class
         )
 
         # UPDATE Y, X AND FINALLY Z
 
+        n_classes = len(n_samples_per_class)
         latent_y = self.update_y(
             X=X,
             y=y,
@@ -1696,7 +1775,7 @@ class JFAMachine(FactorAnalysisBase):
         )
         return latent_y
 
-    def e_step_u(self, X, y, n_classes, latent_y):
+    def e_step_u(self, X, y, n_samples_per_class, latent_y):
         """
         ISV E-step for the U matrix.
 
@@ -1725,14 +1804,16 @@ class JFAMachine(FactorAnalysisBase):
         """
         # self.initialize_XYZ(y)
         UProd = self._compute_uprod()
-        latent_x, _, latent_z = self.initialize_XYZ(y, n_classes)
+        latent_x, _, latent_z = self.initialize_XYZ(
+            n_samples_per_class=n_samples_per_class
+        )
 
-        latent_x = self.update_x(
+        n_classes = len(n_samples_per_class)
+        latent_x = self.compute_latent_x(
             X=X,
             y=y,
             n_classes=n_classes,
             UProd=UProd,
-            latent_x=latent_x,
             latent_y=latent_y,
         )
 
@@ -1747,7 +1828,7 @@ class JFAMachine(FactorAnalysisBase):
 
         return acc_U_A1, acc_U_A2
 
-    def m_step_u(self, acc_U_A1, acc_U_A2):
+    def m_step_u(self, acc_U_A1_acc_U_A2_list):
         """
         `U` Matrix M-step.
         This updates the `U` matrix
 
@@ -1762,6 +1843,11 @@ class JFAMachine(FactorAnalysisBase):
             Accumulated statistics for U_A2(n_gaussians* feature_dimension, r_U)
 
         """
+        acc_U_A1 = [acc[0] for acc in acc_U_A1_acc_U_A2_list]
+        acc_U_A2 = [acc[1] for acc in acc_U_A1_acc_U_A2_list]
+
+        acc_U_A1 = reduce_iadd(acc_U_A1)
+        acc_U_A2 = reduce_iadd(acc_U_A2)
 
         self.update_U(acc_U_A1, acc_U_A2)
@@ -1769,7 +1855,7 @@ class JFAMachine(FactorAnalysisBase):
         self,
         X,
         y,
-        n_classes,
+        n_samples_per_class,
         latent_y,
     ):
         """
@@ -1797,20 +1883,24 @@ class JFAMachine(FactorAnalysisBase):
         """
 
         UProd = self._compute_uprod()
-        latent_x, _, _ = self.initialize_XYZ(y, n_classes=n_classes)
+        latent_x, _, _ = self.initialize_XYZ(
+            n_samples_per_class=n_samples_per_class
+        )
 
-        latent_x = self.update_x(
+        n_classes = len(n_samples_per_class)
+        latent_x = self.compute_latent_x(
             X=X,
             y=y,
             n_classes=n_classes,
             UProd=UProd,
-            latent_x=latent_x,
             latent_y=latent_y,
         )
 
         return latent_x
 
-    def e_step_d(self, X, y, n_classes, latent_x, latent_y, n_acc, f_acc):
+    def e_step_d(
+        self, X, y, n_samples_per_class, latent_x, latent_y, n_acc, f_acc
+    ):
         """
         ISV E-step for the D matrix.
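[Review note, not part of the patch: every e_step_*/m_step_* pair in this
file now follows the same map-reduce shape as ISVMachine.fit_using_stats:
the E-step is mapped over the per-class chunks as dask.delayed tasks, each
returning a tuple of accumulators, and the M-step reduces the list of tuples
with reduce_iadd before updating the corresponding matrix. A stripped-down
sketch of that pattern, with invented names:

    import dask

    def em_round(e_step, m_step, X, y, input_is_dask, **kwargs):
        if input_is_dask:
            # map: one delayed E-step per class chunk
            outputs = [
                dask.delayed(e_step)(X=xx, y=yy, **kwargs)
                for xx, yy in zip(X, y)
            ]
            # reduce: a single delayed M-step consumes all accumulators
            dask.compute(dask.delayed(m_step)(outputs))
        else:
            # eager path: one E-step over everything, M-step on a 1-list
            m_step([e_step(X=X, y=y, **kwargs)])

Each iteration computes its M-step immediately rather than building one
graph across all iterations, because the next E-step depends on the matrix
that the M-step just updated.]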
@@ -1852,8 +1942,9 @@ class JFAMachine(FactorAnalysisBase):
             Accumulated statistics for D_A2(n_gaussians* feature_dimension, )
 
         """
-
-        _, _, latent_z = self.initialize_XYZ(y, n_classes=n_classes)
+        _, _, latent_z = self.initialize_XYZ(
+            n_samples_per_class=n_samples_per_class
+        )
 
         latent_z = self.update_z(
             X,
@@ -1871,7 +1962,7 @@ class JFAMachine(FactorAnalysisBase):
 
         return acc_D_A1, acc_D_A2
 
-    def m_step_d(self, acc_D_A1, acc_D_A2):
+    def m_step_d(self, acc_D_A1_acc_D_A2_list):
         """
         `D` Matrix M-step.
         This updates the `D` matrix
 
@@ -1886,6 +1977,12 @@ class JFAMachine(FactorAnalysisBase):
             Accumulated statistics for D_A2(n_gaussians* feature_dimension, )
 
         """
+        acc_D_A1 = [acc[0] for acc in acc_D_A1_acc_D_A2_list]
+        acc_D_A2 = [acc[1] for acc in acc_D_A1_acc_D_A2_list]
+
+        acc_D_A1 = reduce_iadd(acc_D_A1)
+        acc_D_A2 = reduce_iadd(acc_D_A2)
+
         self._D = acc_D_A2 / acc_D_A1
 
     def enroll_using_stats(self, X, iterations=1):
@@ -1915,7 +2012,9 @@ class JFAMachine(FactorAnalysisBase):
 
         UProd = self._compute_uprod()
         VProd = self._compute_vprod()
-        latent_x, latent_y, latent_z = self.initialize_XYZ(y, n_classes=1)
+        latent_x, latent_y, latent_z = self.initialize_XYZ(
+            n_samples_per_class=[len(X)]
+        )
 
         for i in range(iterations):
             logger.info("Enrollment: Iteration %d", i + 1)
@@ -1930,12 +2029,11 @@ class JFAMachine(FactorAnalysisBase):
                 n_acc=n_acc,
                 f_acc=f_acc,
             )
-            latent_x = self.update_x(
+            latent_x = self.compute_latent_x(
                 X=X,
                 y=y,
                 n_classes=1,
                 UProd=UProd,
-                latent_x=latent_x,
                 latent_y=latent_y,
                 latent_z=latent_z,
             )
@@ -1998,58 +2096,104 @@ class JFAMachine(FactorAnalysisBase):
         ):
             self.create_UVD()
 
-        y = np.asarray(y)
-        n_classes = self.estimate_number_of_classes(y)
+        (
+            input_is_dask,
+            n_classes,
+            n_samples_per_class,
+        ) = check_dask_input_samples_per_class(X, y)
 
-        # TODO: Point of MAP-REDUCE
         n_acc, f_acc = self.initialize_using_stats(X, y, n_classes=n_classes)
 
         # Updating V
         for i in range(self.em_iterations):
             logger.info("V Training: Iteration %d", i + 1)
-            # TODO: Point of MAP-REDUCE
-            acc_V_A1, acc_V_A2 = self.e_step_v(
-                X=X,
-                y=y,
-                n_classes=n_classes,
-                n_acc=n_acc,
-                f_acc=f_acc,
-            )
-            self.m_step_v(acc_V_A1, acc_V_A2)
+            if input_is_dask:
+                e_step_output = [
+                    dask.delayed(self.e_step_v)(
+                        X=xx,
+                        y=yy,
+                        n_samples_per_class=n_samples_per_class,
+                        n_acc=n_acc,
+                        f_acc=f_acc,
+                    )
+                    for xx, yy in zip(X, y)
+                ]
+                delayed_em_step = dask.delayed(self.m_step_v)(e_step_output)
+                dask.compute(delayed_em_step)
+            else:
+                e_step_output = self.e_step_v(
+                    X=X,
+                    y=y,
+                    n_samples_per_class=n_samples_per_class,
+                    n_acc=n_acc,
+                    f_acc=f_acc,
+                )
+                self.m_step_v([e_step_output])
 
         latent_y = self.finalize_v(
-            X=X, y=y, n_classes=n_classes, n_acc=n_acc, f_acc=f_acc
+            X=X,
+            y=y,
+            n_samples_per_class=n_samples_per_class,
+            n_acc=n_acc,
+            f_acc=f_acc,
         )
 
         # Updating U
         for i in range(self.em_iterations):
             logger.info("U Training: Iteration %d", i + 1)
-            # TODO: Point of MAP-REDUCE
-            acc_U_A1, acc_U_A2 = self.e_step_u(
-                X=X,
-                y=y,
-                n_classes=n_classes,
-                latent_y=latent_y,
-            )
-            self.m_step_u(acc_U_A1, acc_U_A2)
+            if input_is_dask:
+                e_step_output = [
+                    dask.delayed(self.e_step_u)(
+                        X=xx,
+                        y=yy,
+                        n_samples_per_class=n_samples_per_class,
+                        latent_y=latent_y,
+                    )
+                    for xx, yy in zip(X, y)
+                ]
+                delayed_em_step = dask.delayed(self.m_step_u)(e_step_output)
+                dask.compute(delayed_em_step)
+            else:
+                e_step_output = self.e_step_u(
+                    X=X,
+                    y=y,
+                    n_samples_per_class=n_samples_per_class,
+                    latent_y=latent_y,
+                )
+                self.m_step_u([e_step_output])
 
         latent_x = self.finalize_u(
-            X=X, y=y, n_classes=n_classes, latent_y=latent_y
+            X=X, y=y, n_samples_per_class=n_samples_per_class, latent_y=latent_y
         )
 
         # Updating D
         for i in range(self.em_iterations):
             logger.info("D Training: Iteration %d", i + 1)
-            # TODO: Point of MAP-REDUCE
-            acc_D_A1, acc_D_A2 = self.e_step_d(
-                X=X,
-                y=y,
-                n_classes=n_classes,
-                latent_x=latent_x,
-                latent_y=latent_y,
-                n_acc=n_acc,
-                f_acc=f_acc,
-            )
-            self.m_step_d(acc_D_A1, acc_D_A2)
+            if input_is_dask:
+                e_step_output = [
+                    dask.delayed(self.e_step_d)(
+                        X=xx,
+                        y=yy,
+                        n_samples_per_class=n_samples_per_class,
+                        latent_x=latent_x,
+                        latent_y=latent_y,
+                        n_acc=n_acc,
+                        f_acc=f_acc,
+                    )
+                    for xx, yy in zip(X, y)
+                ]
+                delayed_em_step = dask.delayed(self.m_step_d)(e_step_output)
+                dask.compute(delayed_em_step)
+            else:
+                e_step_output = self.e_step_d(
+                    X=X,
+                    y=y,
+                    n_samples_per_class=n_samples_per_class,
+                    latent_x=latent_x,
+                    latent_y=latent_y,
+                    n_acc=n_acc,
+                    f_acc=f_acc,
+                )
+                self.m_step_d([e_step_output])
 
         return self
diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index 394a930..25159f1 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -854,7 +854,7 @@ class GMMMachine(BaseEstimator):
         """Returns the statistics for `X`."""
         return e_step(data=X, machine=self)
 
-    def transform_one_by_one(self, X):
+    def stats_per_sample(self, X):
         return [e_step(data=xx, machine=self) for xx in X]
 
     def _more_tags(self):
diff --git a/bob/learn/em/test/test_factor_analysis.py b/bob/learn/em/test/test_factor_analysis.py
index 08ae5ab..8d82d0c 100644
--- a/bob/learn/em/test/test_factor_analysis.py
+++ b/bob/learn/em/test/test_factor_analysis.py
@@ -555,13 +555,14 @@ def test_ISV_JFA_fit():
             machine = JFAMachine(2, 2, **machine_kwargs)
             test_attr = "V"
 
-        with multiprocess_dask_client():
-            machine.fit(data, labels)
+        err_msg = f"Test failed with prior={prior} and machine_type={machine_type} and transform={transform}"
+        # with multiprocess_dask_client():
+        machine.fit(data, labels)
 
         arr = getattr(machine, test_attr)
         np.testing.assert_allclose(
             arr,
             ref,
             atol=1e-7,
-            err_msg=f"Test failed with prior={prior} and machine_type={machine_type} and transform={transform}",
+            err_msg=err_msg,
         )
diff --git a/bob/learn/em/utils.py b/bob/learn/em/utils.py
index 108e797..be92e8d 100644
--- a/bob/learn/em/utils.py
+++ b/bob/learn/em/utils.py
@@ -7,11 +7,12 @@ import numpy as np
 logger = logging.getLogger(__name__)
 
 
-def check_and_persist_dask_input(data):
+def check_and_persist_dask_input(data, persist=True):
     # check if input is a dask array. If so, persist and rebalance data
     input_is_dask = False
     if isinstance(data, da.Array):
-        data: da.Array = data.persist()
+        if persist:
+            data: da.Array = data.persist()
         input_is_dask = True
     # if there is a dask distributed client, rebalance data
     try:
-- 
GitLab
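[Editor's note, appended after the patch: the net effect is that fit() now
accepts either a plain numpy array with a flat label vector or a dask array,
which fit() itself re-splits into one chunk per class before building one
delayed stats_per_sample task per class (hence persist=False in utils). A
rough usage sketch under those assumptions; the machine construction is
hypothetical and abbreviated, see the test module for the real signatures:

    import dask.array as da
    import numpy as np

    # hypothetical machine, e.g. machine = ISVMachine(2, ubm=ubm, ...)

    rng = np.random.default_rng(0)
    data = rng.normal(size=(6, 3))
    labels = np.array([0, 0, 0, 1, 1, 1])

    # numpy path: statistics are computed eagerly via ubm.stats_per_sample
    machine.fit(data, labels)

    # dask path: fit() splits X per class and persists per-class statistics
    machine.fit(da.from_array(data, chunks=(3, 3)), labels)

The test change above also drops multiprocess_dask_client, so the suite now
exercises fit() on the default scheduler only.]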