diff --git a/bob/bio/gmm/algorithm/GMM.py b/bob/bio/gmm/algorithm/GMM.py index e33fc22fd4aa87e47b6c8e3dacda42bcbc8bda64..423fbc516e32fdcc4b53c66b874827bddfec2470 100644 --- a/bob/bio/gmm/algorithm/GMM.py +++ b/bob/bio/gmm/algorithm/GMM.py @@ -24,10 +24,10 @@ from h5py import File as HDF5File from sklearn.base import BaseEstimator from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgorithm -from bob.learn.em.cluster import KMeansMachine -from bob.learn.em.mixture import GMMMachine -from bob.learn.em.mixture import GMMStats -from bob.learn.em.mixture import linear_scoring +from bob.learn.em import KMeansMachine +from bob.learn.em import GMMMachine +from bob.learn.em import GMMStats +from bob.learn.em import linear_scoring logger = logging.getLogger(__name__) @@ -51,19 +51,20 @@ class GMM(BioAlgorithm, BaseEstimator): # parameters of UBM training kmeans_training_iterations: int = 25, # Maximum number of iterations for K-Means kmeans_init_iterations: Union[int,None] = None, # Maximum number of iterations for K-Means init + kmeans_oversampling_factor: int = 64, ubm_training_iterations: int = 25, # Maximum number of iterations for GMM Training training_threshold: float = 5e-4, # Threshold to end the ML training variance_threshold: float = 5e-4, # Minimum value that a variance can reach update_means: bool = True, update_variances: bool = True, update_weights: bool = True, - # parameters of the GMM enrollment - gmm_enroll_iterations: int = 1, # Number of iterations for the enrollment phase + # parameters of the GMM enrollment (MAP) + gmm_enroll_iterations: int = 1, enroll_update_means: bool = True, enroll_update_variances: bool = False, enroll_update_weights: bool = False, - relevance_factor: float = 4, # Relevance factor as described in Reynolds paper - responsibility_threshold: float = 0, # If set, the weight of a particular Gaussian will at least be greater than this threshold. In the case the real weight is lower, the prior mean value will be used to estimate the current mean and variance. + enroll_relevance_factor: Union[float, None] = 4, + enroll_alpha: float = 0.5, # scoring scoring_function: Callable = linear_scoring, # RNG @@ -80,6 +81,8 @@ class GMM(BioAlgorithm, BaseEstimator): kmeans_init_iterations Number of iterations used for setting the k-means initial centroids. if None, will use the same as kmeans_training_iterations. + kmeans_oversampling_factor + Oversampling factor used by k-means initializer. ubm_training_iterations Number of e-m iterations for training the UBM. training_threshold @@ -100,12 +103,11 @@ class GMM(BioAlgorithm, BaseEstimator): Decides wether the means of the Gaussians are updated while enrolling. enroll_update_variances Decides wether the variancess of the Gaussians are updated while enrolling. - relevance_factor - Relevance factor as described in Reynolds paper. - responsibility_threshold - If set, the weight of a particular Gaussian will at least be greater than - this threshold. In the case where the real weight is lower, the prior mean - value will be used to estimate the current mean and variance. + enroll_relevance_factor + For enrollment: MAP relevance factor as described in Reynolds paper. + If None, will not apply Reynolds adaptation. + enroll_alpha + For enrollment: MAP adaptation coefficient. init_seed Seed for the random number generation. scoring_function @@ -115,21 +117,22 @@ class GMM(BioAlgorithm, BaseEstimator): # Copy parameters self.number_of_gaussians = number_of_gaussians self.kmeans_training_iterations = kmeans_training_iterations - self.kmeans_init_iterations = kmeans_init_iterations or kmeans_training_iterations + self.kmeans_init_iterations = kmeans_training_iterations if kmeans_init_iterations is None else kmeans_init_iterations + self.kmeans_oversampling_factor = kmeans_oversampling_factor self.ubm_training_iterations = ubm_training_iterations self.training_threshold = training_threshold self.variance_threshold = variance_threshold self.update_weights = update_weights self.update_means = update_means self.update_variances = update_variances - self.relevance_factor = relevance_factor + self.enroll_relevance_factor = enroll_relevance_factor + self.enroll_alpha = enroll_alpha self.gmm_enroll_iterations = gmm_enroll_iterations self.enroll_update_means = enroll_update_means self.enroll_update_weights = enroll_update_weights self.enroll_update_variances = enroll_update_variances self.init_seed = init_seed self.rng = self.init_seed - self.responsibility_threshold = responsibility_threshold self.scoring_function = scoring_function @@ -192,7 +195,6 @@ class GMM(BioAlgorithm, BaseEstimator): # Use the array to train a GMM and return it logger.info("Enrolling with %d feature vectors", array.shape[0]) - # TODO accept responsibility_threshold in bob.learn.em with dask.config.set(scheduler="threads"): gmm = GMMMachine( n_gaussians=self.number_of_gaussians, @@ -205,13 +207,8 @@ class GMM(BioAlgorithm, BaseEstimator): update_variances=self.enroll_update_variances, update_weights=self.enroll_update_weights, mean_var_update_threshold=self.variance_threshold, - k_means_trainer= KMeansMachine( - n_clusters=self.number_of_gaussians, - init_method="k-means||", - max_iter=self.kmeans_training_iterations, - init_max_iter=self.kmeans_init_iterations, - convergence_threshold=self.training_threshold, - ) + relevance_factor=self.enroll_relevance_factor, + alpha=self.enroll_alpha, ) gmm.fit(array) return gmm @@ -317,7 +314,7 @@ class GMM(BioAlgorithm, BaseEstimator): convergence_threshold=self.training_threshold, max_iter=self.kmeans_training_iterations, init_method="k-means||", - init_max_iter=5, + init_max_iter=self.kmeans_init_iterations, random_state=self.init_seed, ), ) @@ -332,10 +329,21 @@ class GMM(BioAlgorithm, BaseEstimator): def transform(self, X, **kwargs): """Passthrough. Enroll applies a different transform as score.""" # The idea would be to apply the projection in Transform (going from extracted - # to GMMStats), but we must not apply this during the training (fit requires - # extracted data directly). + # to GMMStats), but we must not apply this during the training or enrollment + # (those require extracted data directly, not projected). # `project` is applied in the score function directly. return X + @classmethod + def custom_enrolled_save_fn(cls, data, path): + data.save(path) + + def custom_enrolled_load_fn(self, path): + return GMMMachine.from_hdf5(path, ubm=self.ubm) + def _more_tags(self): - return {"bob_fit_supports_dask_array": True} + return { + "bob_fit_supports_dask_array": True, + "bob_enrolled_save_fn": self.custom_enrolled_save_fn, + "bob_enrolled_load_fn": self.custom_enrolled_load_fn, + } diff --git a/bob/bio/gmm/test/data/gmm_enrolled.hdf5 b/bob/bio/gmm/test/data/gmm_enrolled.hdf5 index 4b48c0fa3c6376c2e5c41aa164dac481c847e380..94190e2b4c1cfb7bdf36627b5fd770fb9079ded0 100644 Binary files a/bob/bio/gmm/test/data/gmm_enrolled.hdf5 and b/bob/bio/gmm/test/data/gmm_enrolled.hdf5 differ diff --git a/bob/bio/gmm/test/data/gmm_projected.hdf5 b/bob/bio/gmm/test/data/gmm_projected.hdf5 index 1bdbba5769c5a7195efe81d6b0b14685ef980097..894d3f7e6594fab0c9720dd71b7f5082f338ad4e 100644 Binary files a/bob/bio/gmm/test/data/gmm_projected.hdf5 and b/bob/bio/gmm/test/data/gmm_projected.hdf5 differ diff --git a/bob/bio/gmm/test/data/gmm_ubm.hdf5 b/bob/bio/gmm/test/data/gmm_ubm.hdf5 index 780baa0277bd2d7cf68ab048c7df69900fa6b875..6f40b3b024d24a0349f122864a1b3c254b81955f 100644 Binary files a/bob/bio/gmm/test/data/gmm_ubm.hdf5 and b/bob/bio/gmm/test/data/gmm_ubm.hdf5 differ diff --git a/bob/bio/gmm/test/test_gmm.py b/bob/bio/gmm/test/test_gmm.py index ea7b6a4c9cc21853924d5eb8fbabb6d8f70d3da9..4a9acb404149b5060a3af453e74dc21a63d89008 100644 --- a/bob/bio/gmm/test/test_gmm.py +++ b/bob/bio/gmm/test/test_gmm.py @@ -28,8 +28,8 @@ import bob.bio.gmm from bob.bio.base.test import utils from bob.bio.gmm.algorithm import GMM -from bob.learn.em.mixture.gmm import GMMMachine -from bob.learn.em.mixture.gmm import GMMStats +from bob.learn.em import GMMMachine +from bob.learn.em import GMMStats logger = logging.getLogger(__name__) @@ -97,7 +97,7 @@ def test_projector(): # Generate and project random feature feature = utils.random_array((20, 45), -5.0, 5.0, seed=seed_value) projected = gmm1.project(feature) - assert isinstance(projected, bob.learn.em.mixture.GMMStats) + assert isinstance(projected, GMMStats) reference_file = pkg_resources.resource_filename( "bob.bio.gmm.test", "data/gmm_projected.hdf5" @@ -137,8 +137,8 @@ def test_enroll(): with tempfile.NamedTemporaryFile(prefix="bob_", suffix="_bioref.hdf5") as fd: temp_file = fd.name - gmm1.write_biometric_reference(biometric_reference, reference_file) - assert os.path.exists(temp_file) + gmm1.write_biometric_reference(biometric_reference, temp_file) + assert GMMMachine.from_hdf5(temp_file, ubm).is_similar_to(gmm2) def test_score():