From ed1032dc1336443ee95fa4eb582b3239659da198 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Mon, 22 Nov 2021 18:06:34 +0100 Subject: [PATCH] Use a custom dask wrapper --- bob/bio/gmm/bioalgorithm/GMM.py | 133 ++++++++++++++++++++------------ 1 file changed, 82 insertions(+), 51 deletions(-) diff --git a/bob/bio/gmm/bioalgorithm/GMM.py b/bob/bio/gmm/bioalgorithm/GMM.py index 05110e7..8ba494c 100644 --- a/bob/bio/gmm/bioalgorithm/GMM.py +++ b/bob/bio/gmm/bioalgorithm/GMM.py @@ -16,6 +16,7 @@ from typing import Callable import dask.array as da import numpy as np +import dask from sklearn.base import BaseEstimator @@ -26,6 +27,7 @@ from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgori from bob.learn.em.mixture import GMMMachine from bob.learn.em.mixture import GMMStats from bob.learn.em.mixture import linear_scoring +from bob.pipelines.wrappers import DaskWrapper logger = logging.getLogger(__name__) @@ -116,7 +118,12 @@ class GMM(BioAlgorithm, BaseEstimator): self.init_seed = init_seed self.rng = self.init_seed # TODO verify if rng object needed self.responsibility_threshold = responsibility_threshold - self.scoring_function = scoring_function + + def scoring_function_wrapped(*args, **kwargs): + with dask.config.set(scheduler="threads"): + return scoring_function(*args, **kwargs).compute() + + self.scoring_function = scoring_function_wrapped self.ubm = None @@ -160,8 +167,10 @@ class GMM(BioAlgorithm, BaseEstimator): self._check_feature(array) logger.debug(" .... Projecting %d feature vectors", array.shape[0]) # Accumulates statistics - gmm_stats = GMMStats(self.ubm.shape[0], self.ubm.shape[1]) - self.ubm.acc_statistics(array, gmm_stats) + with dask.config.set(scheduler="threads"): + gmm_stats = GMMStats(self.ubm.shape[0], self.ubm.shape[1]) + self.ubm.acc_statistics(array, gmm_stats) + gmm_stats.compute() # return the resulting statistics return gmm_stats @@ -182,19 +191,24 @@ class GMM(BioAlgorithm, BaseEstimator): logger.debug(" .... Enrolling with %d feature vectors", array.shape[0]) # TODO responsibility_threshold - gmm = GMMMachine( - n_gaussians=self.number_of_gaussians, - trainer="map", - ubm=self.ubm, - convergence_threshold=self.training_threshold, - max_fitting_steps=self.gmm_enroll_iterations, - random_state=self.rng, - update_means=True, - update_variances=True, # TODO default? - update_weights=True, # TODO default? - ) - gmm.variance_thresholds = self.variance_threshold - gmm = gmm.fit(array) + with dask.config.set(scheduler="threads"): + gmm = GMMMachine( + n_gaussians=self.number_of_gaussians, + trainer="map", + ubm=self.ubm, + convergence_threshold=self.training_threshold, + max_fitting_steps=self.gmm_enroll_iterations, + random_state=self.rng, + update_means=True, + update_variances=True, # TODO default? + update_weights=True, # TODO default? + ) + gmm.variance_thresholds = self.variance_threshold + gmm = gmm.fit(array) + # info = {k: type(v) for k, v in gmm.__dict__.items()} + # for k, v in gmm.gaussians_.__dict__.items(): + # info[k] = type(v) + # raise ValueError(str(info)) return gmm def read_model(self, model_file): @@ -274,41 +288,8 @@ class GMM(BioAlgorithm, BaseEstimator): def fit(self, X, y=None, **kwargs): """Trains the UBM.""" # TODO: Delayed to dask array - - # def delayed_to_xr_dataset(delayed, meta=None): - # """Converts one dask.delayed object to a dask.array""" - # if meta is None: - # meta = np.array(delayed.data.compute()) - # print(meta.shape) - - # darray = da.from_delayed(delayed.data, meta.shape, dtype=meta.dtype, name=False) - # return darray, meta - - # def delayed_samples_to_dask_arrays(delayed_samples, meta=None): - # output = [] - # for ds in delayed_samples: - # d_array, meta = delayed_to_xr_dataset(ds, meta) - # output.append(d_array) - # return output, meta - - # def delayeds_to_xr_dataset(delayeds, meta=None): - # """Converts a set of dask.delayed to a list of dask.array""" - # output = [] - # for d in delayeds: - # d_array, meta = delayed_samples_to_dask_arrays(d, meta) - # output.extend(d_array) - # return output - - # import ipdb; ipdb.set_trace() - - # bags = ToDaskBag(npartitions=10).transform(X) - - # delayeds = bags.to_delayed() - # lengths = bags.map_partitions(lambda samples: [len(samples)]).compute() - # for l, d in zip(lengths, delayeds): - # d._length = l - # array_data = da.from_delayed(delayeds, shape=(2,-1,60)) - # array_data = da.stack(delayeds_to_xr_dataset(delayeds)) + if not all(isinstance(x, da.Array) for x in X): + raise ValueError(f"This function only supports dask arrays, {type(X[0])}") # Stack all the samples in a 2D array of features array = da.vstack(X) @@ -343,3 +324,53 @@ class GMM(BioAlgorithm, BaseEstimator): # extracted data directly). # `project` is applied in the score function directly. return X + + + + +def delayed_to_da(delayed, meta=None): + """Converts one dask.delayed object to a dask.array""" + if meta is None: + meta = np.array(delayed.data.compute()) + print(meta.shape) + + darray = da.from_delayed(delayed.data, meta.shape, dtype=meta.dtype, name=False) + return darray, meta + + +def delayed_samples_to_dask_arrays(delayed_samples, meta=None): + output = [] + for ds in delayed_samples: + d_array, meta = delayed_to_da(ds, meta) + output.append(d_array) + return output, meta + + +def delayeds_to_dask_array(delayeds, meta=None): + """Converts a set of dask.delayed to a list of dask.array""" + output = [] + for d in delayeds: + d_array, meta = delayed_samples_to_dask_arrays(d, meta) + output.extend(d_array) + return output + + +class GMMDaskWrapper(DaskWrapper): + def fit(self, X, y=None, **fit_params): + # convert X which is a dask bag to a dask array + X = X.persist() + delayeds = X.to_delayed() + lengths = X.map_partitions(lambda samples: [len(samples)]).compute() + shapes = X.map_partitions( + lambda samples: [[s.data.shape for s in samples]] + ).compute() + dtype, X = None, [] + for l, s, d in zip(lengths, shapes, delayeds): + d._length = l + for shape, ds in zip(s, d): + if dtype is None: + dtype = np.array(ds.data.compute()).dtype + darray = da.from_delayed(ds.data, shape, dtype=dtype, name=False) + X.append(darray) + self.estimator.fit(X, y, **fit_params) + return self -- GitLab