diff --git a/bob/bio/gmm/algorithm/GMM.py b/bob/bio/gmm/algorithm/GMM.py index a4a0450367f23d7daa2653d5512072b2ed76b12b..6bc5b0e53104acdee689c29579ae65e9d21da5dc 100644 --- a/bob/bio/gmm/algorithm/GMM.py +++ b/bob/bio/gmm/algorithm/GMM.py @@ -7,6 +7,7 @@ import logging import numpy +from dask.delayed import Delayed as DaskDelayed from dask_ml.cluster import KMeans as DistributedKMeans from sklearn.base import BaseEstimator from sklearn.base import TransformerMixin @@ -16,6 +17,7 @@ import bob.io.base import bob.learn.em from bob.bio.base.algorithm import Algorithm +from bob.bio.gmm.algorithm.utils import delayeds_to_xr_dataset from bob.bio.gmm.algorithm.utils import get_variances_and_weights_of_clusters logger = logging.getLogger(__name__) @@ -152,7 +154,6 @@ class GMM(Algorithm): # Creates the machines (KMeans and GMM) logger.debug(" .... Creating machines") - # kmeans = bob.learn.em.KMeansMachine(self.number_of_gaussians, input_size) self.ubm = bob.learn.em.GMMMachine(self.number_of_gaussians, input_size) logger.info(" -> Training K-Means") @@ -167,8 +168,8 @@ class GMM(Algorithm): # Initializes the GMM self.ubm.means = self.kmeans_trainer.cluster_centers_ - self.ubm.variances = variances - self.ubm.weights = weights + self.ubm.variances = variances.compute() + self.ubm.weights = weights.compute() self.ubm.set_variance_thresholds(self.variance_threshold) # Trains the GMM @@ -198,6 +199,11 @@ class GMM(Algorithm): def train_projector(self, train_features, projector_file): """Computes the Universal Background Model from the training ("world") data""" + logger.debug(f"Training projector with type {type(train_features)}") + + if isinstance(train_features, DaskDelayed): + train_features = delayeds_to_xr_dataset(train_features) + [self._check_feature(feature) for feature in train_features] logger.info( diff --git a/bob/bio/gmm/algorithm/utils.py b/bob/bio/gmm/algorithm/utils.py index 013733daa855e436231ca3588735bb5c85460d87..719849da8957b298989d0bf346251075cb175f27 100644 --- a/bob/bio/gmm/algorithm/utils.py +++ b/bob/bio/gmm/algorithm/utils.py @@ -75,7 +75,4 @@ def get_variances_and_weights_of_clusters( means = means_sum / weights_count[:, None] variances = (variances_sum / weights_count[:, None]) - (means ** 2) - logger.debug( - f"get_variances_and_weights_of_clusters: var: {variances.compute()}, weights: {weights.compute()}" - ) return variances, weights diff --git a/bob/bio/gmm/test/data/gmm_model.hdf5 b/bob/bio/gmm/test/data/gmm_model.hdf5 index a57d494c0fc9112e582827d577ae4bf974d2e174..bdf68ad08c8be953bb6fb6d30375ed77957a7dc7 100644 Binary files a/bob/bio/gmm/test/data/gmm_model.hdf5 and b/bob/bio/gmm/test/data/gmm_model.hdf5 differ diff --git a/bob/bio/gmm/test/data/gmm_projected.hdf5 b/bob/bio/gmm/test/data/gmm_projected.hdf5 index 31d930b955098e3ae990c1e2509d2c232d1a86be..c6fdffdeb711a0d32c32bbc44d9a525704122139 100644 Binary files a/bob/bio/gmm/test/data/gmm_projected.hdf5 and b/bob/bio/gmm/test/data/gmm_projected.hdf5 differ diff --git a/bob/bio/gmm/test/data/gmm_projector.hdf5 b/bob/bio/gmm/test/data/gmm_projector.hdf5 index 4c47be97a009e963d25301904a7420eced1b55e9..f40e38c5e81f7c3cc1a8fb8baeedeb28acf9c7ca 100644 Binary files a/bob/bio/gmm/test/data/gmm_projector.hdf5 and b/bob/bio/gmm/test/data/gmm_projector.hdf5 differ diff --git a/bob/bio/gmm/test/test_algorithms.py b/bob/bio/gmm/test/test_algorithms.py index 7cb0bb5b5c5052d936285fd1ee49cf510dbf6665..cdbbd6889407c46d74ad088808f29637da881a76 100644 --- a/bob/bio/gmm/test/test_algorithms.py +++ b/bob/bio/gmm/test/test_algorithms.py @@ -143,7 +143,7 @@ def test_gmm(): probe = gmm1.read_feature( pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_projected.hdf5") ) - reference_score = -0.01676570 + reference_score = -0.06732984 assert ( abs(gmm1.score(model, probe) - reference_score) < 1e-5 ), "The scores differ: %3.8f, %3.8f" % (gmm1.score(model, probe), reference_score) @@ -214,7 +214,7 @@ def test_gmm_regular(): probe = utils.random_array((20, 45), -5.0, 5.0, seed=84) # compare model with probe - reference_score = -0.40840148 + reference_score = -0.48658502 assert ( abs(gmm1.score(model, probe) - reference_score) < 1e-5 ), "The scores differ: %3.8f, %3.8f" % (gmm1.score(model, probe), reference_score)