From ed1032dc1336443ee95fa4eb582b3239659da198 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Mon, 22 Nov 2021 18:06:34 +0100
Subject: [PATCH] Use a custom dask wrapper

---
 bob/bio/gmm/bioalgorithm/GMM.py | 133 ++++++++++++++++++++------------
 1 file changed, 82 insertions(+), 51 deletions(-)

diff --git a/bob/bio/gmm/bioalgorithm/GMM.py b/bob/bio/gmm/bioalgorithm/GMM.py
index 05110e7..8ba494c 100644
--- a/bob/bio/gmm/bioalgorithm/GMM.py
+++ b/bob/bio/gmm/bioalgorithm/GMM.py
@@ -16,6 +16,7 @@ from typing import Callable
 
 import dask.array as da
 import numpy as np
+import dask
 
 from sklearn.base import BaseEstimator
 
@@ -26,6 +27,7 @@ from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgori
 from bob.learn.em.mixture import GMMMachine
 from bob.learn.em.mixture import GMMStats
 from bob.learn.em.mixture import linear_scoring
+from bob.pipelines.wrappers import DaskWrapper
 
 logger = logging.getLogger(__name__)
 
@@ -116,7 +118,12 @@ class GMM(BioAlgorithm, BaseEstimator):
         self.init_seed = init_seed
         self.rng = self.init_seed  # TODO verify if rng object needed
         self.responsibility_threshold = responsibility_threshold
-        self.scoring_function = scoring_function
+
+        def scoring_function_wrapped(*args, **kwargs):
+            with dask.config.set(scheduler="threads"):
+                return scoring_function(*args, **kwargs).compute()
+
+        self.scoring_function = scoring_function_wrapped
 
         self.ubm = None
 
@@ -160,8 +167,10 @@ class GMM(BioAlgorithm, BaseEstimator):
         self._check_feature(array)
         logger.debug(" .... Projecting %d feature vectors", array.shape[0])
         # Accumulates statistics
-        gmm_stats = GMMStats(self.ubm.shape[0], self.ubm.shape[1])
-        self.ubm.acc_statistics(array, gmm_stats)
+        with dask.config.set(scheduler="threads"):
+            gmm_stats = GMMStats(self.ubm.shape[0], self.ubm.shape[1])
+            self.ubm.acc_statistics(array, gmm_stats)
+            gmm_stats.compute()
 
         # return the resulting statistics
         return gmm_stats
@@ -182,19 +191,24 @@ class GMM(BioAlgorithm, BaseEstimator):
         logger.debug(" .... Enrolling with %d feature vectors", array.shape[0])
 
         # TODO responsibility_threshold
-        gmm = GMMMachine(
-            n_gaussians=self.number_of_gaussians,
-            trainer="map",
-            ubm=self.ubm,
-            convergence_threshold=self.training_threshold,
-            max_fitting_steps=self.gmm_enroll_iterations,
-            random_state=self.rng,
-            update_means=True,
-            update_variances=True,  # TODO default?
-            update_weights=True,  # TODO default?
-        )
-        gmm.variance_thresholds = self.variance_threshold
-        gmm = gmm.fit(array)
+        with dask.config.set(scheduler="threads"):
+            gmm = GMMMachine(
+                n_gaussians=self.number_of_gaussians,
+                trainer="map",
+                ubm=self.ubm,
+                convergence_threshold=self.training_threshold,
+                max_fitting_steps=self.gmm_enroll_iterations,
+                random_state=self.rng,
+                update_means=True,
+                update_variances=True,  # TODO default?
+                update_weights=True,  # TODO default?
+            )
+            gmm.variance_thresholds = self.variance_threshold
+            gmm = gmm.fit(array)
+            # info = {k: type(v) for k, v in gmm.__dict__.items()}
+            # for k, v in gmm.gaussians_.__dict__.items():
+            #     info[k] = type(v)
+            # raise ValueError(str(info))
         return gmm
 
     def read_model(self, model_file):
@@ -274,41 +288,8 @@ class GMM(BioAlgorithm, BaseEstimator):
     def fit(self, X, y=None, **kwargs):
         """Trains the UBM."""
         # TODO: Delayed to dask array
-
-        # def delayed_to_xr_dataset(delayed, meta=None):
-        #     """Converts one dask.delayed object to a dask.array"""
-        #     if meta is None:
-        #         meta = np.array(delayed.data.compute())
-        #         print(meta.shape)
-
-        #     darray = da.from_delayed(delayed.data, meta.shape, dtype=meta.dtype, name=False)
-        #     return darray, meta
-
-        # def delayed_samples_to_dask_arrays(delayed_samples, meta=None):
-        #     output = []
-        #     for ds in delayed_samples:
-        #         d_array, meta = delayed_to_xr_dataset(ds, meta)
-        #         output.append(d_array)
-        #     return output, meta
-
-        # def delayeds_to_xr_dataset(delayeds, meta=None):
-        #     """Converts a set of dask.delayed to a list of dask.array"""
-        #     output = []
-        #     for d in delayeds:
-        #         d_array, meta = delayed_samples_to_dask_arrays(d, meta)
-        #         output.extend(d_array)
-        #     return output
-
-        # import ipdb; ipdb.set_trace()
-
-        # bags = ToDaskBag(npartitions=10).transform(X)
-
-        # delayeds = bags.to_delayed()
-        # lengths = bags.map_partitions(lambda samples: [len(samples)]).compute()
-        # for l, d in zip(lengths, delayeds):
-        #     d._length = l
-        # array_data = da.from_delayed(delayeds, shape=(2,-1,60))
-        # array_data = da.stack(delayeds_to_xr_dataset(delayeds))
+        if not all(isinstance(x, da.Array) for x in X):
+            raise ValueError(f"This function only supports dask arrays, {type(X[0])}")
 
         # Stack all the samples in a 2D array of features
         array = da.vstack(X)
@@ -343,3 +324,53 @@ class GMM(BioAlgorithm, BaseEstimator):
         # extracted data directly).
         # `project` is applied in the score function directly.
         return X
+
+
+
+
+def delayed_to_da(delayed, meta=None):
+    """Wrap ``delayed.data`` as a dask array and return ``(darray, meta)``."""
+    # ``meta`` is an array whose shape/dtype describe the data; computed if missing.
+    if meta is None:
+        meta = np.array(delayed.data.compute())
+        logger.debug("Inferred sample shape %s", meta.shape)
+    darray = da.from_delayed(delayed.data, meta.shape, dtype=meta.dtype, name=False)
+    return darray, meta
+
+
+def delayed_samples_to_dask_arrays(delayed_samples, meta=None):  # -> (arrays, meta)
+    arrays = []  # one dask array per delayed sample, in input order
+    for sample in delayed_samples:
+        converted, meta = delayed_to_da(sample, meta)  # meta is reused once known
+        arrays.append(converted)
+    return arrays, meta
+
+
+def delayeds_to_dask_array(delayeds, meta=None):
+    """Flatten groups of delayed samples into one flat list of dask arrays."""
+    arrays = []
+    for group in delayeds:  # ``meta`` from earlier groups carries over to later ones
+        converted, meta = delayed_samples_to_dask_arrays(group, meta)
+        arrays.extend(converted)
+    return arrays
+
+
+class GMMDaskWrapper(DaskWrapper):
+    """DaskWrapper whose ``fit`` hands the estimator one dask array per sample."""
+
+    def fit(self, X, y=None, **fit_params):
+        """Convert the dask bag ``X`` to per-sample dask arrays, fit, return self."""
+        bag = X.persist()  # persist once; the passes below reuse the same bag
+        shapes = bag.map_partitions(  # one list of sample shapes per partition
+            lambda samples: [[s.data.shape for s in samples]]
+        ).compute()
+        dtype, arrays = None, []
+        for part_shapes, partition in zip(shapes, bag.to_delayed()):
+            partition._length = len(part_shapes)  # a length makes the Delayed iterable
+            for shape, sample in zip(part_shapes, partition):
+                if dtype is None:  # dtype is taken from the first sample only
+                    dtype = np.array(sample.data.compute()).dtype
+                darray = da.from_delayed(sample.data, shape, dtype=dtype, name=False)
+                arrays.append(darray)
+        self.estimator.fit(arrays, y, **fit_params)
+        return self
-- 
GitLab