From 78b62b408bb4d4a232b604bcc957e9663b69e165 Mon Sep 17 00:00:00 2001
From: Yannick DAYER <yannick.dayer@idiap.ch>
Date: Mon, 6 Dec 2021 12:56:31 +0100
Subject: [PATCH] Remove projection in score, fix ubm not set in fit

---
 bob/bio/gmm/bioalgorithm/GMM.py | 59 +++------------------------------
 1 file changed, 5 insertions(+), 54 deletions(-)

diff --git a/bob/bio/gmm/bioalgorithm/GMM.py b/bob/bio/gmm/bioalgorithm/GMM.py
index 0247428..ed574e2 100644
--- a/bob/bio/gmm/bioalgorithm/GMM.py
+++ b/bob/bio/gmm/bioalgorithm/GMM.py
@@ -21,8 +21,6 @@ from h5py import File as HDF5File
 
 from sklearn.base import BaseEstimator
 
-import bob.core
-
 from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgorithm
 from bob.learn.em.mixture import GMMMachine
 from bob.learn.em.mixture import GMMStats
@@ -121,7 +119,7 @@ class GMM(BioAlgorithm, BaseEstimator):
 
         def scoring_function_wrapped(*args, **kwargs):
             with dask.config.set(scheduler="threads"):
-                return scoring_function(*args, **kwargs).compute()
+                return scoring_function(*args, **kwargs)
 
         self.scoring_function = scoring_function_wrapped
 
@@ -233,11 +231,10 @@ class GMM(BioAlgorithm, BaseEstimator):
         """
 
         assert isinstance(biometric_reference, GMMMachine)
-        stats = self.project(data)
         return self.scoring_function(
             models_means=[biometric_reference],
             ubm=self.ubm,
-            test_stats=stats,
+            test_stats=data,
             frame_length_normalization=True,
         )[0, 0]
 
@@ -309,7 +306,7 @@ class GMM(BioAlgorithm, BaseEstimator):
         logger.info("Training UBM GMM")
         # Resetting the pseudo random number generator so we can have the same initialization for serial and parallel execution.
         # self.rng = bob.core.random.mt19937(self.init_seed)
-        self.ubm.fit(array)
+        self.ubm = self.ubm.fit(array)
 
         return self
 
@@ -321,51 +318,5 @@ class GMM(BioAlgorithm, BaseEstimator):
         # `project` is applied in the score function directly.
         return X
 
-
-
-
-def delayed_to_da(delayed, meta=None):
-    """Converts one dask.delayed object to a dask.array"""
-    if meta is None:
-        meta = np.array(delayed.data.compute())
-
-    darray = da.from_delayed(delayed.data, meta.shape, dtype=meta.dtype, name=False)
-    return darray, meta
-
-
-def delayed_samples_to_dask_arrays(delayed_samples, meta=None):
-    output = []
-    for ds in delayed_samples:
-        d_array, meta = delayed_to_da(ds, meta)
-        output.append(d_array)
-    return output, meta
-
-
-def delayeds_to_dask_array(delayeds, meta=None):
-    """Converts a set of dask.delayed to a list of dask.array"""
-    output = []
-    for d in delayeds:
-        d_array, meta = delayed_samples_to_dask_arrays(d, meta)
-        output.extend(d_array)
-    return output
-
-
-class GMMDaskWrapper(DaskWrapper):
-    def fit(self, X, y=None, **fit_params):
-        # convert X which is a dask bag to a dask array
-        X = X.persist()
-        delayeds = X.to_delayed()
-        lengths = X.map_partitions(lambda samples: [len(samples)]).compute()
-        shapes = X.map_partitions(
-            lambda samples: [[s.data.shape for s in samples]]
-        ).compute()
-        dtype, X = None, []
-        for l, s, d in zip(lengths, shapes, delayeds):
-            d._length = l
-            for shape, ds in zip(s, d):
-                if dtype is None:
-                    dtype = np.array(ds.data.compute()).dtype
-                darray = da.from_delayed(ds.data, shape, dtype=dtype, name=False)
-                X.append(darray)
-        self.estimator.fit(X, y, **fit_params)
-        return self
+    def _more_tags(self):
+        return {"bob_fit_supports_dask_array": True}
-- 
GitLab