Skip to content
Snippets Groups Projects
Commit 78b62b40 authored by Yannick DAYER's avatar Yannick DAYER
Browse files

Remove projection in score, fix ubm not set in fit

parent d079cddf
No related branches found
No related tags found
1 merge request!26Python implementation of GMM
Pipeline #56789 failed
......@@ -21,8 +21,6 @@ from h5py import File as HDF5File
from sklearn.base import BaseEstimator
import bob.core
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgorithm
from bob.learn.em.mixture import GMMMachine
from bob.learn.em.mixture import GMMStats
......@@ -121,7 +119,7 @@ class GMM(BioAlgorithm, BaseEstimator):
def scoring_function_wrapped(*args, **kwargs):
with dask.config.set(scheduler="threads"):
return scoring_function(*args, **kwargs).compute()
return scoring_function(*args, **kwargs)
self.scoring_function = scoring_function_wrapped
......@@ -233,11 +231,10 @@ class GMM(BioAlgorithm, BaseEstimator):
"""
assert isinstance(biometric_reference, GMMMachine)
stats = self.project(data)
return self.scoring_function(
models_means=[biometric_reference],
ubm=self.ubm,
test_stats=stats,
test_stats=data,
frame_length_normalization=True,
)[0, 0]
......@@ -309,7 +306,7 @@ class GMM(BioAlgorithm, BaseEstimator):
logger.info("Training UBM GMM")
# Resetting the pseudo random number generator so we can have the same initialization for serial and parallel execution.
# self.rng = bob.core.random.mt19937(self.init_seed)
self.ubm.fit(array)
self.ubm = self.ubm.fit(array)
return self
......@@ -321,51 +318,5 @@ class GMM(BioAlgorithm, BaseEstimator):
# `project` is applied in the score function directly.
return X
def delayed_to_da(delayed, meta=None):
"""Converts one dask.delayed object to a dask.array"""
if meta is None:
meta = np.array(delayed.data.compute())
darray = da.from_delayed(delayed.data, meta.shape, dtype=meta.dtype, name=False)
return darray, meta
def delayed_samples_to_dask_arrays(delayed_samples, meta=None):
output = []
for ds in delayed_samples:
d_array, meta = delayed_to_da(ds, meta)
output.append(d_array)
return output, meta
def delayeds_to_dask_array(delayeds, meta=None):
"""Converts a set of dask.delayed to a list of dask.array"""
output = []
for d in delayeds:
d_array, meta = delayed_samples_to_dask_arrays(d, meta)
output.extend(d_array)
return output
class GMMDaskWrapper(DaskWrapper):
def fit(self, X, y=None, **fit_params):
# convert X which is a dask bag to a dask array
X = X.persist()
delayeds = X.to_delayed()
lengths = X.map_partitions(lambda samples: [len(samples)]).compute()
shapes = X.map_partitions(
lambda samples: [[s.data.shape for s in samples]]
).compute()
dtype, X = None, []
for l, s, d in zip(lengths, shapes, delayeds):
d._length = l
for shape, ds in zip(s, d):
if dtype is None:
dtype = np.array(ds.data.compute()).dtype
darray = da.from_delayed(ds.data, shape, dtype=dtype, name=False)
X.append(darray)
self.estimator.fit(X, y, **fit_params)
return self
def _more_tags(self):
return {"bob_fit_supports_dask_array": True}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment