diff --git a/bob/bio/gmm/algorithm/GMM.py b/bob/bio/gmm/algorithm/GMM.py index 3fb0d2c138b9f5bebb3dfd5965ed849d4803117e..6ea1e4617ac79394a0cd0dc61535aba46609528d 100644 --- a/bob/bio/gmm/algorithm/GMM.py +++ b/bob/bio/gmm/algorithm/GMM.py @@ -16,7 +16,6 @@ import logging from typing import Callable from typing import Union -import dask import dask.array as da import numpy as np @@ -71,6 +70,7 @@ class GMM(BioAlgorithm, BaseEstimator): scoring_function: Callable = linear_scoring, # RNG init_seed: int = 5489, + **kwargs, ): """Initializes the local UBM-GMM tool chain. @@ -144,7 +144,7 @@ class GMM(BioAlgorithm, BaseEstimator): self.ubm = None - super().__init__() + super().__init__(**kwargs) def _check_feature(self, feature): """Checks that the features are appropriate""" @@ -196,27 +196,28 @@ class GMM(BioAlgorithm, BaseEstimator): Returns a GMMMachine tuned from the UBM with MAP on a biometric reference data. """ - [self._check_feature(feature) for feature in data] - array = da.vstack(data) + for feature in data: + self._check_feature(feature) + + data = np.vstack(data) # Use the array to train a GMM and return it - logger.info("Enrolling with %d feature vectors", array.shape[0]) + logger.info("Enrolling with %d feature vectors", data.shape[0]) - with dask.config.set(scheduler="threads"): - gmm = GMMMachine( - n_gaussians=self.number_of_gaussians, - trainer="map", - ubm=copy.deepcopy(self.ubm), - convergence_threshold=self.training_threshold, - max_fitting_steps=self.gmm_enroll_iterations, - random_state=self.rng, - update_means=self.enroll_update_means, - update_variances=self.enroll_update_variances, - update_weights=self.enroll_update_weights, - mean_var_update_threshold=self.variance_threshold, - map_relevance_factor=self.enroll_relevance_factor, - map_alpha=self.enroll_alpha, - ) - gmm.fit(array) + gmm = GMMMachine( + n_gaussians=self.number_of_gaussians, + trainer="map", + ubm=copy.deepcopy(self.ubm), + convergence_threshold=self.training_threshold, + max_fitting_steps=self.gmm_enroll_iterations, + random_state=self.rng, + update_means=self.enroll_update_means, + update_variances=self.enroll_update_variances, + update_weights=self.enroll_update_weights, + mean_var_update_threshold=self.variance_threshold, + map_relevance_factor=self.enroll_relevance_factor, + map_alpha=self.enroll_alpha, + ) + gmm.fit(data) return gmm def read_biometric_reference(self, model_file): @@ -277,10 +278,11 @@ class GMM(BioAlgorithm, BaseEstimator): frame_length_normalization=True, ) - def fit(self, X, y=None, **kwargs): + def fit(self, array, y=None, **kwargs): """Trains the UBM.""" # Stack all the samples in a 2D array of features - array = da.vstack(X).persist() + if isinstance(array, da.Array): + array = array.persist() logger.debug("UBM with %d feature vectors", array.shape[0]) @@ -309,7 +311,7 @@ class GMM(BioAlgorithm, BaseEstimator): # Train the GMM logger.info("Training UBM GMM") - self.ubm.fit(array, ubm_train=True) + self.ubm.fit(array) return self