diff --git a/bob/bio/gmm/algorithm/GMM.py b/bob/bio/gmm/algorithm/GMM.py index 3fb0d2c138b9f5bebb3dfd5965ed849d4803117e..9f39171fde81725cebd02e7f0494defef2818ce2 100644 --- a/bob/bio/gmm/algorithm/GMM.py +++ b/bob/bio/gmm/algorithm/GMM.py @@ -16,7 +16,6 @@ import logging from typing import Callable from typing import Union -import dask import dask.array as da import numpy as np @@ -71,6 +70,7 @@ class GMM(BioAlgorithm, BaseEstimator): scoring_function: Callable = linear_scoring, # RNG init_seed: int = 5489, + **kwargs, ): """Initializes the local UBM-GMM tool chain. @@ -144,7 +144,7 @@ class GMM(BioAlgorithm, BaseEstimator): self.ubm = None - super().__init__() + super().__init__(**kwargs) def _check_feature(self, feature): """Checks that the features are appropriate""" @@ -196,27 +196,28 @@ class GMM(BioAlgorithm, BaseEstimator): Returns a GMMMachine tuned from the UBM with MAP on a biometric reference data. """ - [self._check_feature(feature) for feature in data] - array = da.vstack(data) + for feature in data: + self._check_feature(feature) + + data = np.vstack(data) # Use the array to train a GMM and return it - logger.info("Enrolling with %d feature vectors", array.shape[0]) + logger.info("Enrolling with %d feature vectors", data.shape[0]) - with dask.config.set(scheduler="threads"): - gmm = GMMMachine( - n_gaussians=self.number_of_gaussians, - trainer="map", - ubm=copy.deepcopy(self.ubm), - convergence_threshold=self.training_threshold, - max_fitting_steps=self.gmm_enroll_iterations, - random_state=self.rng, - update_means=self.enroll_update_means, - update_variances=self.enroll_update_variances, - update_weights=self.enroll_update_weights, - mean_var_update_threshold=self.variance_threshold, - map_relevance_factor=self.enroll_relevance_factor, - map_alpha=self.enroll_alpha, - ) - gmm.fit(array) + gmm = GMMMachine( + n_gaussians=self.number_of_gaussians, + trainer="map", + ubm=copy.deepcopy(self.ubm), + convergence_threshold=self.training_threshold, + max_fitting_steps=self.gmm_enroll_iterations, + random_state=self.rng, + update_means=self.enroll_update_means, + update_variances=self.enroll_update_variances, + update_weights=self.enroll_update_weights, + mean_var_update_threshold=self.variance_threshold, + map_relevance_factor=self.enroll_relevance_factor, + map_alpha=self.enroll_alpha, + ) + gmm.fit(data) return gmm def read_biometric_reference(self, model_file): @@ -277,12 +278,13 @@ class GMM(BioAlgorithm, BaseEstimator): frame_length_normalization=True, ) - def fit(self, X, y=None, **kwargs): + def fit(self, array, y=None, **kwargs): """Trains the UBM.""" # Stack all the samples in a 2D array of features - array = da.vstack(X).persist() + if isinstance(array, da.Array): + array = array.persist() - logger.debug("UBM with %d feature vectors", array.shape[0]) + logger.debug("UBM with %d feature vectors", len(array)) logger.debug(f"Creating UBM machine with {self.number_of_gaussians} gaussians") @@ -309,7 +311,7 @@ class GMM(BioAlgorithm, BaseEstimator): # Train the GMM logger.info("Training UBM GMM") - self.ubm.fit(array, ubm_train=True) + self.ubm.fit(array) return self diff --git a/bob/bio/gmm/test/data/gmm_enrolled.hdf5 b/bob/bio/gmm/test/data/gmm_enrolled.hdf5 index 94190e2b4c1cfb7bdf36627b5fd770fb9079ded0..478955e623c0985d70c89d04ed2fcb5481de0db2 100644 Binary files a/bob/bio/gmm/test/data/gmm_enrolled.hdf5 and b/bob/bio/gmm/test/data/gmm_enrolled.hdf5 differ diff --git a/bob/bio/gmm/test/data/gmm_projected.hdf5 b/bob/bio/gmm/test/data/gmm_projected.hdf5 index 894d3f7e6594fab0c9720dd71b7f5082f338ad4e..8a217ae273fc50374f401f442457a2d2971a3035 100644 Binary files a/bob/bio/gmm/test/data/gmm_projected.hdf5 and b/bob/bio/gmm/test/data/gmm_projected.hdf5 differ diff --git a/bob/bio/gmm/test/data/gmm_ubm.hdf5 b/bob/bio/gmm/test/data/gmm_ubm.hdf5 index 6f40b3b024d24a0349f122864a1b3c254b81955f..63ce75bd0e44ac19105079d4a831e435460a450e 100644 Binary files a/bob/bio/gmm/test/data/gmm_ubm.hdf5 and b/bob/bio/gmm/test/data/gmm_ubm.hdf5 differ diff --git a/bob/bio/gmm/test/test_gmm.py b/bob/bio/gmm/test/test_gmm.py index 809d8d3626eaf21870211c93b39b1a54c7f8c049..2cfd519af824e6c77e5620f22edf94af67e67531 100644 --- a/bob/bio/gmm/test/test_gmm.py +++ b/bob/bio/gmm/test/test_gmm.py @@ -64,6 +64,7 @@ def test_training(): train_data = utils.random_training_set( (100, 45), count=5, minimum=-5.0, maximum=5.0 ) + train_data = numpy.vstack(train_data) # Train the UBM (projector) gmm1.fit(train_data) @@ -155,7 +156,7 @@ def test_score(): ) probe_data = utils.random_array((20, 45), -5.0, 5.0, seed=seed_value) - reference_score = 0.601025 + reference_score = 0.6509 numpy.testing.assert_almost_equal( gmm1.score(biometric_reference, probe), reference_score, decimal=5