Skip to content
Snippets Groups Projects
Commit 78b62b40 authored by Yannick DAYER
Browse files

Remove projection in score, fix ubm not set in fit

parent d079cddf
Branches
No related tags found
1 merge request: !26 "Python implementation of GMM"
Pipeline #56789 failed
...@@ -21,8 +21,6 @@ from h5py import File as HDF5File ...@@ -21,8 +21,6 @@ from h5py import File as HDF5File
from sklearn.base import BaseEstimator from sklearn.base import BaseEstimator
import bob.core
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgorithm from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgorithm
from bob.learn.em.mixture import GMMMachine from bob.learn.em.mixture import GMMMachine
from bob.learn.em.mixture import GMMStats from bob.learn.em.mixture import GMMStats
...@@ -121,7 +119,7 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -121,7 +119,7 @@ class GMM(BioAlgorithm, BaseEstimator):
def scoring_function_wrapped(*args, **kwargs): def scoring_function_wrapped(*args, **kwargs):
with dask.config.set(scheduler="threads"): with dask.config.set(scheduler="threads"):
return scoring_function(*args, **kwargs).compute() return scoring_function(*args, **kwargs)
self.scoring_function = scoring_function_wrapped self.scoring_function = scoring_function_wrapped
...@@ -233,11 +231,10 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -233,11 +231,10 @@ class GMM(BioAlgorithm, BaseEstimator):
""" """
assert isinstance(biometric_reference, GMMMachine) assert isinstance(biometric_reference, GMMMachine)
stats = self.project(data)
return self.scoring_function( return self.scoring_function(
models_means=[biometric_reference], models_means=[biometric_reference],
ubm=self.ubm, ubm=self.ubm,
test_stats=stats, test_stats=data,
frame_length_normalization=True, frame_length_normalization=True,
)[0, 0] )[0, 0]
...@@ -309,7 +306,7 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -309,7 +306,7 @@ class GMM(BioAlgorithm, BaseEstimator):
logger.info("Training UBM GMM") logger.info("Training UBM GMM")
# Resetting the pseudo random number generator so we can have the same initialization for serial and parallel execution. # Resetting the pseudo random number generator so we can have the same initialization for serial and parallel execution.
# self.rng = bob.core.random.mt19937(self.init_seed) # self.rng = bob.core.random.mt19937(self.init_seed)
self.ubm.fit(array) self.ubm = self.ubm.fit(array)
return self return self
...@@ -321,51 +318,5 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -321,51 +318,5 @@ class GMM(BioAlgorithm, BaseEstimator):
# `project` is applied in the score function directly. # `project` is applied in the score function directly.
return X return X
def _more_tags(self):
    """Return the extra estimator tags declared by this estimator.

    Returns
    -------
    dict
        A single tag, ``"bob_fit_supports_dask_array"``, set to ``True``.
        NOTE(review): presumably read by the dask wrapping code to decide
        whether ``fit`` may be handed dask arrays -- confirm against the
        consumer of this tag.
    """
    return {"bob_fit_supports_dask_array": True}
def delayed_to_da(delayed, meta=None):
    """Convert one dask.delayed object to a dask.array.

    Parameters
    ----------
    delayed
        Object whose ``data`` attribute is a dask delayed value.
    meta
        Optional array supplying the shape and dtype of the result. When
        ``None``, ``delayed.data`` is computed once and wrapped with
        ``np.array`` to obtain them.

    Returns
    -------
    tuple
        ``(darray, meta)``: the array built by ``da.from_delayed`` and the
        ``meta`` array that was used, so the caller can pass it on to the
        next conversion.
    """
    if meta is None:
        # No template available: evaluate this sample once to learn its
        # shape and dtype.
        meta = np.array(delayed.data.compute())

    # NOTE(review): a caller-supplied ``meta`` is used as is, so this sample
    # is assumed to have the same shape and dtype -- confirm with callers.
    darray = da.from_delayed(delayed.data, meta.shape, dtype=meta.dtype, name=False)
    return darray, meta
def delayed_samples_to_dask_arrays(delayed_samples, meta=None):
    """Convert an iterable of delayed samples to a list of dask arrays.

    Each sample is converted with ``delayed_to_da``. The ``meta`` returned by
    one conversion is passed to the next, so at most the first sample is
    computed to discover shape and dtype.

    Parameters
    ----------
    delayed_samples
        Iterable of objects accepted by ``delayed_to_da``.
    meta
        Optional shape/dtype template forwarded to the first conversion.

    Returns
    -------
    tuple
        ``(output, meta)``: the list of dask arrays and the last ``meta``
        in use (unchanged from the input when ``delayed_samples`` is empty).
    """
    output = []
    for ds in delayed_samples:
        # Not a comprehension: ``meta`` is threaded through the iterations.
        d_array, meta = delayed_to_da(ds, meta)
        output.append(d_array)
    return output, meta
def delayeds_to_dask_array(delayeds, meta=None):
    """Convert groups of dask.delayed samples to one flat list of dask.array.

    Parameters
    ----------
    delayeds
        Iterable of groups; each group is an iterable of delayed samples
        accepted by ``delayed_samples_to_dask_arrays``.
    meta
        Optional shape/dtype template. It is carried from one group to the
        next, so it is resolved at most once.

    Returns
    -------
    list
        The dask arrays of every group, concatenated in iteration order.
    """
    output = []
    for d in delayeds:
        d_array, meta = delayed_samples_to_dask_arrays(d, meta)
        output.extend(d_array)
    return output
class GMMDaskWrapper(DaskWrapper):
    """``DaskWrapper`` whose ``fit`` hands the estimator a list of dask arrays."""

    def fit(self, X, y=None, **fit_params):
        """Fit the wrapped estimator on a dask bag of samples.

        Parameters
        ----------
        X
            Dask bag of samples; each sample exposes its array as ``data``.
        y
            Forwarded unchanged to the wrapped estimator's ``fit``.
        **fit_params
            Forwarded unchanged to the wrapped estimator's ``fit``.

        Returns
        -------
        GMMDaskWrapper
            ``self``.
        """
        # convert X which is a dask bag to a dask array
        X = X.persist()
        delayeds = X.to_delayed()
        # Per-partition sample counts and per-sample shapes are computed up
        # front; both are needed to build the arrays below.
        lengths = X.map_partitions(lambda samples: [len(samples)]).compute()
        shapes = X.map_partitions(
            lambda samples: [[s.data.shape for s in samples]]
        ).compute()

        dtype = None
        arrays = []
        for length, partition_shapes, partition in zip(lengths, shapes, delayeds):
            # A Delayed can only be iterated once its length is known, so set
            # it from the precomputed partition size before zipping over it.
            partition._length = length
            for shape, ds in zip(partition_shapes, partition):
                if dtype is None:
                    # Only the first sample is computed; its dtype is reused
                    # for every array.
                    dtype = np.array(ds.data.compute()).dtype
                darray = da.from_delayed(ds.data, shape, dtype=dtype, name=False)
                arrays.append(darray)

        self.estimator.fit(arrays, y, **fit_params)
        return self
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment