diff --git a/bob/bio/gmm/algorithm/GMM.py b/bob/bio/gmm/algorithm/GMM.py index 4b7310c49c118dce3fe5538c3378c2138d3b8463..672ed2309b386db0571085f23f3b2aca7ea3d230 100644 --- a/bob/bio/gmm/algorithm/GMM.py +++ b/bob/bio/gmm/algorithm/GMM.py @@ -30,8 +30,6 @@ from bob.learn.em.mixture import linear_scoring logger = logging.getLogger(__name__) -# from bob.pipelines import ToDaskBag # Used when switching from samples to da.Array - class GMM(BioAlgorithm, BaseEstimator): """Algorithm for computing UBM and Gaussian Mixture Models of the features. @@ -109,7 +107,7 @@ class GMM(BioAlgorithm, BaseEstimator): Function returning a score from a model, a UBM, and a probe. """ - # copy parameters + # Copy parameters self.number_of_gaussians = number_of_gaussians self.kmeans_training_iterations = kmeans_training_iterations self.ubm_training_iterations = ubm_training_iterations @@ -148,7 +146,7 @@ class GMM(BioAlgorithm, BaseEstimator): ) def save_model(self, ubm_file): - """Saves the projector to file.""" + """Saves the projector (UBM) to file.""" # Saves the UBM to file logger.debug("Saving model to file '%s'", ubm_file) @@ -156,44 +154,39 @@ class GMM(BioAlgorithm, BaseEstimator): self.ubm.save(hdf5) def load_model(self, ubm_file): - """Loads the projector from a file.""" + """Loads the projector (UBM) from a file.""" hdf5file = HDF5File(ubm_file, "r") logger.debug("Loading model from file '%s'", ubm_file) - # Read UBM + # Read the UBM self.ubm = GMMMachine.from_hdf5(hdf5file) self.ubm.variance_thresholds = self.variance_threshold def project(self, array): - """Computes GMM statistics against a UBM, given a 2D array of feature vectors""" + """Computes GMM statistics against a UBM, given a 2D array of feature vectors + + This is applied to the probes before scoring. + """ self._check_feature(array) logger.debug("Projecting %d feature vectors", array.shape[0]) # Accumulates statistics gmm_stats = self.ubm.transform(array) gmm_stats.compute() - # return the resulting statistics + # Return the resulting statistics return gmm_stats - def read_feature(self, feature_file): - """Read the type of features that we require, namely GMM_Stats""" - return GMMStats.from_hdf5(HDF5File(feature_file, "r")) - - def write_feature(self, feature, feature_file): - """Write the features (GMM_Stats)""" - return feature.save(feature_file) - def enroll(self, data): """Enrolls a GMM using MAP adaptation given a reference's feature vectors - Returns a GMMMachine tweaked from the UBM with MAP + Returns a GMMMachine tuned from the UBM with MAP on a biometric reference data. """ [self._check_feature(feature) for feature in data] array = da.vstack(data) # Use the array to train a GMM and return it - logger.debug(" .... Enrolling with %d feature vectors", array.shape[0]) + logger.info("Enrolling with %d feature vectors", array.shape[0]) - # TODO responsibility_threshold + # TODO accept responsibility_threshold in bob.learn.em with dask.config.set(scheduler="threads"): gmm = GMMMachine( n_gaussians=self.number_of_gaussians, @@ -205,18 +198,21 @@ class GMM(BioAlgorithm, BaseEstimator): update_means=self.enroll_update_means, update_variances=self.enroll_update_variances, update_weights=self.enroll_update_weights, + mean_var_update_threshold=self.variance_threshold, ) - gmm.variance_thresholds = self.variance_threshold gmm.fit(array) return gmm def read_biometric_reference(self, model_file): - """Reads an enrolled reference model, which is a MAP GMMMachine""" + """Reads an enrolled reference model, which is a MAP GMMMachine.""" + if self.ubm is None: + raise ValueError( + "You must load a UBM before reading a biometric reference." + ) return GMMMachine.from_hdf5(HDF5File(model_file, "r"), ubm=self.ubm) - @classmethod - def write_biometric_reference(cls, model: GMMMachine, model_file): - """Write the enrolled reference (MAP GMMMachine)""" + def write_biometric_reference(self, model: GMMMachine, model_file): + """Write the enrolled reference (MAP GMMMachine) into a file.""" return model.save(model_file) def score(self, biometric_reference: GMMMachine, probe): @@ -307,6 +303,7 @@ class GMM(BioAlgorithm, BaseEstimator): update_means=self.update_means, update_variances=self.update_variances, update_weights=self.update_weights, + mean_var_update_threshold=self.variance_threshold, k_means_trainer=KMeansMachine( self.number_of_gaussians, convergence_threshold=self.training_threshold, diff --git a/bob/bio/gmm/test/data/gmm_enrolled.hdf5 b/bob/bio/gmm/test/data/gmm_enrolled.hdf5 index 4dba2c87521088f9ae6a3a99ba1170768ebacdc2..2e6e337f592e53f56806bce07dd5556c3176e6ae 100644 Binary files a/bob/bio/gmm/test/data/gmm_enrolled.hdf5 and b/bob/bio/gmm/test/data/gmm_enrolled.hdf5 differ diff --git a/bob/bio/gmm/test/data/gmm_projected.hdf5 b/bob/bio/gmm/test/data/gmm_projected.hdf5 index 602a4184ae4fd4232a24db35aee2525e87558e80..84437324be483253b54aff3fc2d0f1c2e1e620de 100644 Binary files a/bob/bio/gmm/test/data/gmm_projected.hdf5 and b/bob/bio/gmm/test/data/gmm_projected.hdf5 differ diff --git a/bob/bio/gmm/test/data/gmm_ubm.hdf5 b/bob/bio/gmm/test/data/gmm_ubm.hdf5 index 50b42a96d67e594f3f80d869759f8a135f7a937c..99df686343ee83dc8332c56561675cb91fd98ad8 100644 Binary files a/bob/bio/gmm/test/data/gmm_ubm.hdf5 and b/bob/bio/gmm/test/data/gmm_ubm.hdf5 differ diff --git a/bob/bio/gmm/test/test_gmm.py b/bob/bio/gmm/test/test_gmm.py index e60ec31e5d6754686f4ab2a507e8be49332b7b95..404becce825f821e4637c44bb5effd8cdb314bd4 100644 --- a/bob/bio/gmm/test/test_gmm.py +++ b/bob/bio/gmm/test/test_gmm.py @@ -18,7 +18,6 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. import logging -import os import tempfile import pkg_resources @@ -28,6 +27,7 @@ import bob.bio.gmm from bob.bio.base.test import utils from bob.bio.gmm.algorithm import GMM from bob.learn.em.mixture import GMMMachine +from bob.learn.em.mixture import GMMStats logger = logging.getLogger(__name__) @@ -50,6 +50,7 @@ def test_class(): def test_training(): """Tests the generation of the UBM.""" + # Set a small training iteration count gmm1 = GMM( number_of_gaussians=2, kmeans_training_iterations=1, @@ -59,24 +60,26 @@ def test_training(): train_data = utils.random_training_set( (100, 45), count=5, minimum=-5.0, maximum=5.0 ) - reference_file = pkg_resources.resource_filename( - "bob.bio.gmm.test", "data/gmm_ubm.hdf5" - ) - # Train the projector + # Train the UBM (projector) gmm1.fit(train_data) + # Test saving and loading of projector with tempfile.NamedTemporaryFile(prefix="bob_", suffix="_model.hdf5") as fd: temp_file = fd.name gmm1.save_model(temp_file) - assert os.path.exists(temp_file) - + reference_file = pkg_resources.resource_filename( + "bob.bio.gmm.test", "data/gmm_ubm.hdf5" + ) if regenerate_refs: gmm1.save_model(reference_file) - gmm1.ubm = GMMMachine.from_hdf5(reference_file) - assert gmm1.ubm.is_similar_to(GMMMachine.from_hdf5(temp_file)) + gmm2 = GMM(number_of_gaussians=2) + + gmm2.load_model(temp_file) + ubm_reference = GMMMachine.from_hdf5(reference_file) + assert gmm2.ubm.is_similar_to(ubm_reference) def test_projector(): @@ -92,14 +95,13 @@ def test_projector(): projected = gmm1.project(feature) assert isinstance(projected, bob.learn.em.mixture.GMMStats) - reference_path = pkg_resources.resource_filename( + reference_file = pkg_resources.resource_filename( "bob.bio.gmm.test", "data/gmm_projected.hdf5" ) - if regenerate_refs: - projected.save(reference_path) + projected.save(reference_file) - reference = gmm1.read_feature(reference_path) + reference = GMMStats.from_hdf5(reference_file) assert projected.is_similar_to(reference) @@ -122,24 +124,23 @@ def test_enroll(): reference_file = pkg_resources.resource_filename( "bob.bio.gmm.test", "data/gmm_enrolled.hdf5" ) - if regenerate_refs: biometric_reference.save(reference_file) - gmm2 = GMMMachine.from_hdf5(reference_file, ubm=ubm) + gmm2 = gmm1.read_biometric_reference(reference_file) assert biometric_reference.is_similar_to(gmm2) def test_score(): gmm1 = GMM(number_of_gaussians=2) - gmm1.ubm = GMMMachine.from_hdf5( + gmm1.load_model( pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_ubm.hdf5") ) biometric_reference = GMMMachine.from_hdf5( pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_enrolled.hdf5"), ubm=gmm1.ubm, ) - probe = gmm1.read_feature( + probe = GMMStats.from_hdf5( pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_projected.hdf5") )