Commit ed1032dc authored by Amir MOHAMMADI, committed by Yannick DAYER

Use a custom dask wrapper

parent d779abda
Merge request !26: Python implementation of GMM
@@ -16,6 +16,7 @@ from typing import Callable
 import dask.array as da
 import numpy as np
+import dask
 from sklearn.base import BaseEstimator
@@ -26,6 +27,7 @@ from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgorithm
 from bob.learn.em.mixture import GMMMachine
 from bob.learn.em.mixture import GMMStats
 from bob.learn.em.mixture import linear_scoring
+from bob.pipelines.wrappers import DaskWrapper

 logger = logging.getLogger(__name__)
@@ -116,7 +118,12 @@ class GMM(BioAlgorithm, BaseEstimator):
         self.init_seed = init_seed
         self.rng = self.init_seed  # TODO verify if rng object needed
         self.responsibility_threshold = responsibility_threshold
-        self.scoring_function = scoring_function
+
+        def scoring_function_wrapped(*args, **kwargs):
+            with dask.config.set(scheduler="threads"):
+                return scoring_function(*args, **kwargs).compute()
+
+        self.scoring_function = scoring_function_wrapped

         self.ubm = None
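Aside (not part of the commit): when its inputs are dask-backed, the configured scoring function returns a lazy dask collection, so the constructor wraps it to force an eager result under the threaded scheduler. A minimal, self-contained sketch of the same pattern, with a stand-in for linear_scoring:

import dask
import dask.array as da
import numpy as np

def lazy_scores(x):
    # Stand-in for linear_scoring: returns a lazy dask array.
    return da.from_array(np.asarray(x), chunks=2) * 2

def scores_wrapped(*args, **kwargs):
    # Force eager evaluation with the threaded (shared-memory) scheduler.
    with dask.config.set(scheduler="threads"):
        return lazy_scores(*args, **kwargs).compute()

print(scores_wrapped([1, 2, 3, 4]))  # -> [2 4 6 8]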
@@ -160,8 +167,10 @@ class GMM(BioAlgorithm, BaseEstimator):
         self._check_feature(array)
         logger.debug(" .... Projecting %d feature vectors", array.shape[0])
         # Accumulates statistics
-        gmm_stats = GMMStats(self.ubm.shape[0], self.ubm.shape[1])
-        self.ubm.acc_statistics(array, gmm_stats)
+        with dask.config.set(scheduler="threads"):
+            gmm_stats = GMMStats(self.ubm.shape[0], self.ubm.shape[1])
+            self.ubm.acc_statistics(array, gmm_stats)
+            gmm_stats.compute()

         # return the resulting statistics
         return gmm_stats
@@ -182,19 +191,24 @@ class GMM(BioAlgorithm, BaseEstimator):
         logger.debug(" .... Enrolling with %d feature vectors", array.shape[0])
         # TODO responsibility_threshold
-        gmm = GMMMachine(
-            n_gaussians=self.number_of_gaussians,
-            trainer="map",
-            ubm=self.ubm,
-            convergence_threshold=self.training_threshold,
-            max_fitting_steps=self.gmm_enroll_iterations,
-            random_state=self.rng,
-            update_means=True,
-            update_variances=True,  # TODO default?
-            update_weights=True,  # TODO default?
-        )
-        gmm.variance_thresholds = self.variance_threshold
-        gmm = gmm.fit(array)
+        with dask.config.set(scheduler="threads"):
+            gmm = GMMMachine(
+                n_gaussians=self.number_of_gaussians,
+                trainer="map",
+                ubm=self.ubm,
+                convergence_threshold=self.training_threshold,
+                max_fitting_steps=self.gmm_enroll_iterations,
+                random_state=self.rng,
+                update_means=True,
+                update_variances=True,  # TODO default?
+                update_weights=True,  # TODO default?
+            )
+            gmm.variance_thresholds = self.variance_threshold
+            gmm = gmm.fit(array)
+        # info = {k: type(v) for k, v in gmm.__dict__.items()}
+        # for k, v in gmm.gaussians_.__dict__.items():
+        #     info[k] = type(v)
+        # raise ValueError(str(info))
         return gmm

     def read_model(self, model_file):
@@ -274,41 +288,8 @@ class GMM(BioAlgorithm, BaseEstimator):
     def fit(self, X, y=None, **kwargs):
         """Trains the UBM."""
-        # TODO: Delayed to dask array
-        # def delayed_to_xr_dataset(delayed, meta=None):
-        #     """Converts one dask.delayed object to a dask.array"""
-        #     if meta is None:
-        #         meta = np.array(delayed.data.compute())
-        #         print(meta.shape)
-        #     darray = da.from_delayed(delayed.data, meta.shape, dtype=meta.dtype, name=False)
-        #     return darray, meta
-
-        # def delayed_samples_to_dask_arrays(delayed_samples, meta=None):
-        #     output = []
-        #     for ds in delayed_samples:
-        #         d_array, meta = delayed_to_xr_dataset(ds, meta)
-        #         output.append(d_array)
-        #     return output, meta
-
-        # def delayeds_to_xr_dataset(delayeds, meta=None):
-        #     """Converts a set of dask.delayed to a list of dask.array"""
-        #     output = []
-        #     for d in delayeds:
-        #         d_array, meta = delayed_samples_to_dask_arrays(d, meta)
-        #         output.extend(d_array)
-        #     return output
-
-        # import ipdb; ipdb.set_trace()
-        # bags = ToDaskBag(npartitions=10).transform(X)
-        # delayeds = bags.to_delayed()
-        # lengths = bags.map_partitions(lambda samples: [len(samples)]).compute()
-        # for l, d in zip(lengths, delayeds):
-        #     d._length = l
-        # array_data = da.from_delayed(delayeds, shape=(2,-1,60))
-        # array_data = da.stack(delayeds_to_xr_dataset(delayeds))
         if not all(isinstance(x, da.Array) for x in X):
             raise ValueError(f"This function only supports dask arrays, {type(X[0])}")

         # Stack all the samples in a 2D array of features
         array = da.vstack(X)
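Aside (not in the commit): fit expects X to be a list of per-sample dask arrays, each shaped (n_frames, n_features), and vstack concatenates them along the frame axis into one tall 2D array for UBM training. A small illustration with made-up shapes:

import dask.array as da
import numpy as np

# Three hypothetical samples with 5, 7, and 3 frames of 60-dimensional features.
X = [da.from_array(np.random.normal(size=(n, 60))) for n in (5, 7, 3)]
array = da.vstack(X)
print(array.shape)  # (15, 60) -- frames stacked, feature dimension kept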
@@ -343,3 +324,53 @@ class GMM(BioAlgorithm, BaseEstimator):
         # extracted data directly).
         # `project` is applied in the score function directly.
         return X
+
+
+def delayed_to_da(delayed, meta=None):
+    """Converts one dask.delayed object to a dask.array"""
+    if meta is None:
+        meta = np.array(delayed.data.compute())
+        print(meta.shape)
+    darray = da.from_delayed(delayed.data, meta.shape, dtype=meta.dtype, name=False)
+    return darray, meta
+
+
+def delayed_samples_to_dask_arrays(delayed_samples, meta=None):
+    output = []
+    for ds in delayed_samples:
+        d_array, meta = delayed_to_da(ds, meta)
+        output.append(d_array)
+    return output, meta
+
+
+def delayeds_to_dask_array(delayeds, meta=None):
+    """Converts a set of dask.delayed to a list of dask.array"""
+    output = []
+    for d in delayeds:
+        d_array, meta = delayed_samples_to_dask_arrays(d, meta)
+        output.extend(d_array)
+    return output
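A hedged usage sketch for the helpers above (not from the commit): converting the partitions of a dask bag into a flat list of dask arrays. Each partition's Delayed needs its `_length` set before it can be iterated, which is exactly what GMMDaskWrapper.fit below does; the SimpleNamespace sample is a stand-in for a bob.pipelines sample exposing a `.data` array.

import dask.bag
import numpy as np
from types import SimpleNamespace

samples = [SimpleNamespace(data=np.ones((4, 60))) for _ in range(6)]
bag = dask.bag.from_sequence(samples, npartitions=3)
delayeds = bag.to_delayed()
lengths = bag.map_partitions(lambda part: [len(part)]).compute()
for d, n in zip(delayeds, lengths):
    d._length = n  # make each partition's Delayed iterable
arrays = delayeds_to_dask_array(delayeds)
print(len(arrays), arrays[0].shape)  # 6 (4, 60)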
+
+
+class GMMDaskWrapper(DaskWrapper):
+    def fit(self, X, y=None, **fit_params):
+        # Convert X, which is a dask bag of samples, to a list of dask arrays
+        X = X.persist()
+        delayeds = X.to_delayed()
+        lengths = X.map_partitions(lambda samples: [len(samples)]).compute()
+        shapes = X.map_partitions(
+            lambda samples: [[s.data.shape for s in samples]]
+        ).compute()
+        dtype, X = None, []
+        for l, s, d in zip(lengths, shapes, delayeds):
+            d._length = l
+            for shape, ds in zip(s, d):
+                if dtype is None:
+                    dtype = np.array(ds.data.compute()).dtype
+                darray = da.from_delayed(ds.data, shape, dtype=dtype, name=False)
+                X.append(darray)
+        self.estimator.fit(X, y, **fit_params)
+        return self
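A hedged end-to-end sketch (not part of the commit) of what the wrapper achieves: the wrapped estimator's fit receives a list of dask arrays even though the pipeline hands the wrapper a dask bag of samples. The stand-in estimator, and the assumption that DaskWrapper's constructor takes the estimator as its first argument, follow bob.pipelines conventions but are not verified here.

import dask.bag
import numpy as np
from types import SimpleNamespace

class EchoEstimator:
    # Stand-in for the GMM estimator; just reports what it receives.
    def fit(self, X, y=None, **kwargs):
        print(type(X[0]), len(X))  # <class 'dask.array.core.Array'> 6
        return self

samples = [SimpleNamespace(data=np.zeros((4, 60))) for _ in range(6)]
bag = dask.bag.from_sequence(samples, npartitions=3)
GMMDaskWrapper(EchoEstimator()).fit(bag)  # assumed signature: DaskWrapper(estimator, ...)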