From a451243afc5c0a8d2f63f44af0940ee15d05797a Mon Sep 17 00:00:00 2001 From: Yannick DAYER <yannick.dayer@idiap.ch> Date: Thu, 17 Feb 2022 19:38:15 +0100 Subject: [PATCH] Use load_model and read_biometric_reference --- bob/bio/gmm/algorithm/GMM.py | 45 +++++++++++------------ bob/bio/gmm/test/data/gmm_enrolled.hdf5 | Bin 12920 -> 12920 bytes bob/bio/gmm/test/data/gmm_projected.hdf5 | Bin 10608 -> 10608 bytes bob/bio/gmm/test/data/gmm_ubm.hdf5 | Bin 12920 -> 12920 bytes bob/bio/gmm/test/test_gmm.py | 35 +++++++++--------- 5 files changed, 39 insertions(+), 41 deletions(-) diff --git a/bob/bio/gmm/algorithm/GMM.py b/bob/bio/gmm/algorithm/GMM.py index 4b7310c..672ed23 100644 --- a/bob/bio/gmm/algorithm/GMM.py +++ b/bob/bio/gmm/algorithm/GMM.py @@ -30,8 +30,6 @@ from bob.learn.em.mixture import linear_scoring logger = logging.getLogger(__name__) -# from bob.pipelines import ToDaskBag # Used when switching from samples to da.Array - class GMM(BioAlgorithm, BaseEstimator): """Algorithm for computing UBM and Gaussian Mixture Models of the features. @@ -109,7 +107,7 @@ class GMM(BioAlgorithm, BaseEstimator): Function returning a score from a model, a UBM, and a probe. """ - # copy parameters + # Copy parameters self.number_of_gaussians = number_of_gaussians self.kmeans_training_iterations = kmeans_training_iterations self.ubm_training_iterations = ubm_training_iterations @@ -148,7 +146,7 @@ class GMM(BioAlgorithm, BaseEstimator): ) def save_model(self, ubm_file): - """Saves the projector to file.""" + """Saves the projector (UBM) to file.""" # Saves the UBM to file logger.debug("Saving model to file '%s'", ubm_file) @@ -156,44 +154,39 @@ class GMM(BioAlgorithm, BaseEstimator): self.ubm.save(hdf5) def load_model(self, ubm_file): - """Loads the projector from a file.""" + """Loads the projector (UBM) from a file.""" hdf5file = HDF5File(ubm_file, "r") logger.debug("Loading model from file '%s'", ubm_file) - # Read UBM + # Read the UBM self.ubm = GMMMachine.from_hdf5(hdf5file) self.ubm.variance_thresholds = self.variance_threshold def project(self, array): - """Computes GMM statistics against a UBM, given a 2D array of feature vectors""" + """Computes GMM statistics against a UBM, given a 2D array of feature vectors + + This is applied to the probes before scoring. + """ self._check_feature(array) logger.debug("Projecting %d feature vectors", array.shape[0]) # Accumulates statistics gmm_stats = self.ubm.transform(array) gmm_stats.compute() - # return the resulting statistics + # Return the resulting statistics return gmm_stats - def read_feature(self, feature_file): - """Read the type of features that we require, namely GMM_Stats""" - return GMMStats.from_hdf5(HDF5File(feature_file, "r")) - - def write_feature(self, feature, feature_file): - """Write the features (GMM_Stats)""" - return feature.save(feature_file) - def enroll(self, data): """Enrolls a GMM using MAP adaptation given a reference's feature vectors - Returns a GMMMachine tweaked from the UBM with MAP + Returns a GMMMachine tuned from the UBM with MAP on a biometric reference data. """ [self._check_feature(feature) for feature in data] array = da.vstack(data) # Use the array to train a GMM and return it - logger.debug(" .... Enrolling with %d feature vectors", array.shape[0]) + logger.info("Enrolling with %d feature vectors", array.shape[0]) - # TODO responsibility_threshold + # TODO accept responsibility_threshold in bob.learn.em with dask.config.set(scheduler="threads"): gmm = GMMMachine( n_gaussians=self.number_of_gaussians, @@ -205,18 +198,21 @@ class GMM(BioAlgorithm, BaseEstimator): update_means=self.enroll_update_means, update_variances=self.enroll_update_variances, update_weights=self.enroll_update_weights, + mean_var_update_threshold=self.variance_threshold, ) - gmm.variance_thresholds = self.variance_threshold gmm.fit(array) return gmm def read_biometric_reference(self, model_file): - """Reads an enrolled reference model, which is a MAP GMMMachine""" + """Reads an enrolled reference model, which is a MAP GMMMachine.""" + if self.ubm is None: + raise ValueError( + "You must load a UBM before reading a biometric reference." + ) return GMMMachine.from_hdf5(HDF5File(model_file, "r"), ubm=self.ubm) - @classmethod - def write_biometric_reference(cls, model: GMMMachine, model_file): - """Write the enrolled reference (MAP GMMMachine)""" + def write_biometric_reference(self, model: GMMMachine, model_file): + """Write the enrolled reference (MAP GMMMachine) into a file.""" return model.save(model_file) def score(self, biometric_reference: GMMMachine, probe): @@ -307,6 +303,7 @@ class GMM(BioAlgorithm, BaseEstimator): update_means=self.update_means, update_variances=self.update_variances, update_weights=self.update_weights, + mean_var_update_threshold=self.variance_threshold, k_means_trainer=KMeansMachine( self.number_of_gaussians, convergence_threshold=self.training_threshold, diff --git a/bob/bio/gmm/test/data/gmm_enrolled.hdf5 b/bob/bio/gmm/test/data/gmm_enrolled.hdf5 index 4dba2c87521088f9ae6a3a99ba1170768ebacdc2..2e6e337f592e53f56806bce07dd5556c3176e6ae 100644 GIT binary patch delta 159 zcmey7@*`!#9A+lQDVygphcPiRPMJK9dD3JBW~Iri*n&X9|JarZLWMV5Na`>`*<dA7 z(qJWd(rRGgf0|8`3pBGPH_0l2RKd8D<e=QHi5Ff>;^5hAqHux}Y_#S)77#s84{Ebs T(PRZ?jmZ-XLqPiHF;4;jd9F9= delta 159 zcmey7@*`!#9A>7|wVUTLhcPjouAMxOdD3JBW~Iri*n&X9|JarZLWMV5Na`>`*<dA7 z(qJWd(rRGgf0|8`3pBGPH_0l2RKd8D<e=QHi5Ff>;^5hAqHux}Y_#S)77#s84{Ebs T(PRZ?jmZ-XLqPiHF;4;jR%k&1 diff --git a/bob/bio/gmm/test/data/gmm_projected.hdf5 b/bob/bio/gmm/test/data/gmm_projected.hdf5 index 602a4184ae4fd4232a24db35aee2525e87558e80..84437324be483253b54aff3fc2d0f1c2e1e620de 100644 GIT binary patch delta 124 zcmewm^dV@&9A+klDVygphcPiROqo27dD3JBW~Iri*n&X9|JddUg6Me?VvJDfi5ECF z2S~|)xSR8&)fm|srtl>(FjP!DSUXvPS!lABY!_50W8%SFu#lA;TxjFLSjNp73LcyQ D`hF&5 delta 124 zcmewm^dV@&9A>6dwVUTLhcPjos+~NKdD3JBW~Iri*n&X9|JddUg6Me?VvJDfi5ECF z2S~|)xSR8&)fm}N)$%1VFjP!DSUXvPS!lABY!_50W8%SFu#lA;TxjFLSjNp73LcyQ Dza1>F diff --git a/bob/bio/gmm/test/data/gmm_ubm.hdf5 b/bob/bio/gmm/test/data/gmm_ubm.hdf5 index 50b42a96d67e594f3f80d869759f8a135f7a937c..99df686343ee83dc8332c56561675cb91fd98ad8 100644 GIT binary patch delta 179 zcmey7@*`!#9A+klDVygphcPiROqo27dD3JBW~Iri*n&X9|JarZLWMV5Na`>`*<dA7 z(qJWd(rRGgf0|8`3pBGPH_0l2RKd8D<e=QHi5FgM)=;?2%JFCA$4kCR4)&WDa;P#* qHeg}ftfz4fWX9%yn)6sd*3Q#|x<;>PvI4Wl<Ozl$AT#GNPXYjS^*{Ik delta 179 zcmey7@*`!#9A>6dwVUTLhcPjos+~NKdD3JBW~Iri*n&X9|JarZLWMV5Na`>`*<dA7 z(qJWd(rRGgf0|8`3pBGPH_0l2RKd8D<e=QHi5FgM)=;?2%E15z8*DZ&<WOarY{0^{ mSx@5}$c)YZH0QB^tevL^b&X!pWCdo8$rB7iKxWQko&*5;I7MXu diff --git a/bob/bio/gmm/test/test_gmm.py b/bob/bio/gmm/test/test_gmm.py index e60ec31..404becc 100644 --- a/bob/bio/gmm/test/test_gmm.py +++ b/bob/bio/gmm/test/test_gmm.py @@ -18,7 +18,6 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. import logging -import os import tempfile import pkg_resources @@ -28,6 +27,7 @@ import bob.bio.gmm from bob.bio.base.test import utils from bob.bio.gmm.algorithm import GMM from bob.learn.em.mixture import GMMMachine +from bob.learn.em.mixture import GMMStats logger = logging.getLogger(__name__) @@ -50,6 +50,7 @@ def test_class(): def test_training(): """Tests the generation of the UBM.""" + # Set a small training iteration count gmm1 = GMM( number_of_gaussians=2, kmeans_training_iterations=1, @@ -59,24 +60,26 @@ def test_training(): train_data = utils.random_training_set( (100, 45), count=5, minimum=-5.0, maximum=5.0 ) - reference_file = pkg_resources.resource_filename( - "bob.bio.gmm.test", "data/gmm_ubm.hdf5" - ) - # Train the projector + # Train the UBM (projector) gmm1.fit(train_data) + # Test saving and loading of projector with tempfile.NamedTemporaryFile(prefix="bob_", suffix="_model.hdf5") as fd: temp_file = fd.name gmm1.save_model(temp_file) - assert os.path.exists(temp_file) - + reference_file = pkg_resources.resource_filename( + "bob.bio.gmm.test", "data/gmm_ubm.hdf5" + ) if regenerate_refs: gmm1.save_model(reference_file) - gmm1.ubm = GMMMachine.from_hdf5(reference_file) - assert gmm1.ubm.is_similar_to(GMMMachine.from_hdf5(temp_file)) + gmm2 = GMM(number_of_gaussians=2) + + gmm2.load_model(temp_file) + ubm_reference = GMMMachine.from_hdf5(reference_file) + assert gmm2.ubm.is_similar_to(ubm_reference) def test_projector(): @@ -92,14 +95,13 @@ def test_projector(): projected = gmm1.project(feature) assert isinstance(projected, bob.learn.em.mixture.GMMStats) - reference_path = pkg_resources.resource_filename( + reference_file = pkg_resources.resource_filename( "bob.bio.gmm.test", "data/gmm_projected.hdf5" ) - if regenerate_refs: - projected.save(reference_path) + projected.save(reference_file) - reference = gmm1.read_feature(reference_path) + reference = GMMStats.from_hdf5(reference_file) assert projected.is_similar_to(reference) @@ -122,24 +124,23 @@ def test_enroll(): reference_file = pkg_resources.resource_filename( "bob.bio.gmm.test", "data/gmm_enrolled.hdf5" ) - if regenerate_refs: biometric_reference.save(reference_file) - gmm2 = GMMMachine.from_hdf5(reference_file, ubm=ubm) + gmm2 = gmm1.read_biometric_reference(reference_file) assert biometric_reference.is_similar_to(gmm2) def test_score(): gmm1 = GMM(number_of_gaussians=2) - gmm1.ubm = GMMMachine.from_hdf5( + gmm1.load_model( pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_ubm.hdf5") ) biometric_reference = GMMMachine.from_hdf5( pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_enrolled.hdf5"), ubm=gmm1.ubm, ) - probe = gmm1.read_feature( + probe = GMMStats.from_hdf5( pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_projected.hdf5") ) -- GitLab