From a451243afc5c0a8d2f63f44af0940ee15d05797a Mon Sep 17 00:00:00 2001
From: Yannick DAYER <yannick.dayer@idiap.ch>
Date: Thu, 17 Feb 2022 19:38:15 +0100
Subject: [PATCH] Use load_model and read_biometric_reference

---
 bob/bio/gmm/algorithm/GMM.py             |  45 +++++++++++------------
 bob/bio/gmm/test/data/gmm_enrolled.hdf5  | Bin 12920 -> 12920 bytes
 bob/bio/gmm/test/data/gmm_projected.hdf5 | Bin 10608 -> 10608 bytes
 bob/bio/gmm/test/data/gmm_ubm.hdf5       | Bin 12920 -> 12920 bytes
 bob/bio/gmm/test/test_gmm.py             |  35 +++++++++---------
 5 files changed, 39 insertions(+), 41 deletions(-)

diff --git a/bob/bio/gmm/algorithm/GMM.py b/bob/bio/gmm/algorithm/GMM.py
index 4b7310c..672ed23 100644
--- a/bob/bio/gmm/algorithm/GMM.py
+++ b/bob/bio/gmm/algorithm/GMM.py
@@ -30,8 +30,6 @@ from bob.learn.em.mixture import linear_scoring
 
 logger = logging.getLogger(__name__)
 
-# from bob.pipelines import ToDaskBag  # Used when switching from samples to da.Array
-
 
 class GMM(BioAlgorithm, BaseEstimator):
     """Algorithm for computing UBM and Gaussian Mixture Models of the features.
@@ -109,7 +107,7 @@ class GMM(BioAlgorithm, BaseEstimator):
             Function returning a score from a model, a UBM, and a probe.
         """
 
-        # copy parameters
+        # Copy parameters
         self.number_of_gaussians = number_of_gaussians
         self.kmeans_training_iterations = kmeans_training_iterations
         self.ubm_training_iterations = ubm_training_iterations
@@ -148,7 +146,7 @@ class GMM(BioAlgorithm, BaseEstimator):
             )
 
     def save_model(self, ubm_file):
-        """Saves the projector to file."""
+        """Saves the projector (UBM) to file."""
         # Saves the UBM to file
         logger.debug("Saving model to file '%s'", ubm_file)
 
@@ -156,44 +154,39 @@ class GMM(BioAlgorithm, BaseEstimator):
         self.ubm.save(hdf5)
 
     def load_model(self, ubm_file):
-        """Loads the projector from a file."""
+        """Loads the projector (UBM) from a file."""
         hdf5file = HDF5File(ubm_file, "r")
         logger.debug("Loading model from file '%s'", ubm_file)
-        # Read UBM
+        # Read the UBM
         self.ubm = GMMMachine.from_hdf5(hdf5file)
         self.ubm.variance_thresholds = self.variance_threshold
 
     def project(self, array):
-        """Computes GMM statistics against a UBM, given a 2D array of feature vectors"""
+        """Computes GMM statistics against a UBM, given a 2D array of feature vectors
+
+        This is applied to the probes before scoring.
+        """
         self._check_feature(array)
         logger.debug("Projecting %d feature vectors", array.shape[0])
         # Accumulates statistics
         gmm_stats = self.ubm.transform(array)
         gmm_stats.compute()
 
-        # return the resulting statistics
+        # Return the resulting statistics
         return gmm_stats
 
-    def read_feature(self, feature_file):
-        """Read the type of features that we require, namely GMM_Stats"""
-        return GMMStats.from_hdf5(HDF5File(feature_file, "r"))
-
-    def write_feature(self, feature, feature_file):
-        """Write the features (GMM_Stats)"""
-        return feature.save(feature_file)
-
     def enroll(self, data):
         """Enrolls a GMM using MAP adaptation given a reference's feature vectors
 
-        Returns a GMMMachine tweaked from the UBM with MAP
+        Returns a GMMMachine tuned from the UBM with MAP on biometric reference data.
         """
 
         [self._check_feature(feature) for feature in data]
         array = da.vstack(data)
         # Use the array to train a GMM and return it
-        logger.debug(" .... Enrolling with %d feature vectors", array.shape[0])
+        logger.info("Enrolling with %d feature vectors", array.shape[0])
 
-        # TODO responsibility_threshold
+        # TODO accept responsibility_threshold in bob.learn.em
         with dask.config.set(scheduler="threads"):
             gmm = GMMMachine(
                 n_gaussians=self.number_of_gaussians,
@@ -205,18 +198,21 @@ class GMM(BioAlgorithm, BaseEstimator):
                 update_means=self.enroll_update_means,
                 update_variances=self.enroll_update_variances,
                 update_weights=self.enroll_update_weights,
+                mean_var_update_threshold=self.variance_threshold,
             )
-            gmm.variance_thresholds = self.variance_threshold
             gmm.fit(array)
         return gmm
 
     def read_biometric_reference(self, model_file):
-        """Reads an enrolled reference model, which is a MAP GMMMachine"""
+        """Reads an enrolled reference model, which is a MAP GMMMachine."""
+        if self.ubm is None:
+            raise ValueError(
+                "You must load a UBM before reading a biometric reference."
+            )
         return GMMMachine.from_hdf5(HDF5File(model_file, "r"), ubm=self.ubm)
 
-    @classmethod
-    def write_biometric_reference(cls, model: GMMMachine, model_file):
-        """Write the enrolled reference (MAP GMMMachine)"""
+    def write_biometric_reference(self, model: GMMMachine, model_file):
+        """Writes the enrolled reference (MAP GMMMachine) to a file."""
         return model.save(model_file)
 
     def score(self, biometric_reference: GMMMachine, probe):
@@ -307,6 +303,7 @@ class GMM(BioAlgorithm, BaseEstimator):
             update_means=self.update_means,
             update_variances=self.update_variances,
             update_weights=self.update_weights,
+            mean_var_update_threshold=self.variance_threshold,
             k_means_trainer=KMeansMachine(
                 self.number_of_gaussians,
                 convergence_threshold=self.training_threshold,
diff --git a/bob/bio/gmm/test/data/gmm_enrolled.hdf5 b/bob/bio/gmm/test/data/gmm_enrolled.hdf5
index 4dba2c87521088f9ae6a3a99ba1170768ebacdc2..2e6e337f592e53f56806bce07dd5556c3176e6ae 100644
GIT binary patch
delta 159
zcmey7@*`!#9A+lQDVygphcPiRPMJK9dD3JBW~Iri*n&X9|JarZLWMV5Na`>`*<dA7
z(qJWd(rRGgf0|8`3pBGPH_0l2RKd8D<e=QHi5Ff>;^5hAqHux}Y_#S)77#s84{Ebs
T(PRZ?jmZ-XLqPiHF;4;jd9F9=

delta 159
zcmey7@*`!#9A>7|wVUTLhcPjouAMxOdD3JBW~Iri*n&X9|JarZLWMV5Na`>`*<dA7
z(qJWd(rRGgf0|8`3pBGPH_0l2RKd8D<e=QHi5Ff>;^5hAqHux}Y_#S)77#s84{Ebs
T(PRZ?jmZ-XLqPiHF;4;jR%k&1

diff --git a/bob/bio/gmm/test/data/gmm_projected.hdf5 b/bob/bio/gmm/test/data/gmm_projected.hdf5
index 602a4184ae4fd4232a24db35aee2525e87558e80..84437324be483253b54aff3fc2d0f1c2e1e620de 100644
GIT binary patch
delta 124
zcmewm^dV@&9A+klDVygphcPiROqo27dD3JBW~Iri*n&X9|JddUg6Me?VvJDfi5ECF
z2S~|)xSR8&)fm|srtl>(FjP!DSUXvPS!lABY!_50W8%SFu#lA;TxjFLSjNp73LcyQ
D`hF&5

delta 124
zcmewm^dV@&9A>6dwVUTLhcPjos+~NKdD3JBW~Iri*n&X9|JddUg6Me?VvJDfi5ECF
z2S~|)xSR8&)fm}N)$%1VFjP!DSUXvPS!lABY!_50W8%SFu#lA;TxjFLSjNp73LcyQ
Dza1>F

diff --git a/bob/bio/gmm/test/data/gmm_ubm.hdf5 b/bob/bio/gmm/test/data/gmm_ubm.hdf5
index 50b42a96d67e594f3f80d869759f8a135f7a937c..99df686343ee83dc8332c56561675cb91fd98ad8 100644
GIT binary patch
delta 179
zcmey7@*`!#9A+klDVygphcPiROqo27dD3JBW~Iri*n&X9|JarZLWMV5Na`>`*<dA7
z(qJWd(rRGgf0|8`3pBGPH_0l2RKd8D<e=QHi5FgM)=;?2%JFCA$4kCR4)&WDa;P#*
qHeg}ftfz4fWX9%yn)6sd*3Q#|x<;>PvI4Wl<Ozl$AT#GNPXYjS^*{Ik

delta 179
zcmey7@*`!#9A>6dwVUTLhcPjos+~NKdD3JBW~Iri*n&X9|JarZLWMV5Na`>`*<dA7
z(qJWd(rRGgf0|8`3pBGPH_0l2RKd8D<e=QHi5FgM)=;?2%E15z8*DZ&<WOarY{0^{
mSx@5}$c)YZH0QB^tevL^b&X!pWCdo8$rB7iKxWQko&*5;I7MXu

diff --git a/bob/bio/gmm/test/test_gmm.py b/bob/bio/gmm/test/test_gmm.py
index e60ec31..404becc 100644
--- a/bob/bio/gmm/test/test_gmm.py
+++ b/bob/bio/gmm/test/test_gmm.py
@@ -18,7 +18,6 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 import logging
-import os
 import tempfile
 
 import pkg_resources
@@ -28,6 +27,7 @@ import bob.bio.gmm
 from bob.bio.base.test import utils
 from bob.bio.gmm.algorithm import GMM
 from bob.learn.em.mixture import GMMMachine
+from bob.learn.em.mixture import GMMStats
 
 logger = logging.getLogger(__name__)
 
@@ -50,6 +50,7 @@ def test_class():
 
 def test_training():
     """Tests the generation of the UBM."""
+    # Set a small training iteration count
     gmm1 = GMM(
         number_of_gaussians=2,
         kmeans_training_iterations=1,
@@ -59,24 +60,26 @@ def test_training():
     train_data = utils.random_training_set(
         (100, 45), count=5, minimum=-5.0, maximum=5.0
     )
-    reference_file = pkg_resources.resource_filename(
-        "bob.bio.gmm.test", "data/gmm_ubm.hdf5"
-    )
 
-    # Train the projector
+    # Train the UBM (projector)
     gmm1.fit(train_data)
 
+    # Test saving and loading of projector
     with tempfile.NamedTemporaryFile(prefix="bob_", suffix="_model.hdf5") as fd:
         temp_file = fd.name
         gmm1.save_model(temp_file)
 
-        assert os.path.exists(temp_file)
-
+        reference_file = pkg_resources.resource_filename(
+            "bob.bio.gmm.test", "data/gmm_ubm.hdf5"
+        )
         if regenerate_refs:
             gmm1.save_model(reference_file)
 
-        gmm1.ubm = GMMMachine.from_hdf5(reference_file)
-        assert gmm1.ubm.is_similar_to(GMMMachine.from_hdf5(temp_file))
+        gmm2 = GMM(number_of_gaussians=2)
+
+        gmm2.load_model(temp_file)
+        ubm_reference = GMMMachine.from_hdf5(reference_file)
+        assert gmm2.ubm.is_similar_to(ubm_reference)
 
 
 def test_projector():
@@ -92,14 +95,13 @@ def test_projector():
     projected = gmm1.project(feature)
     assert isinstance(projected, bob.learn.em.mixture.GMMStats)
 
-    reference_path = pkg_resources.resource_filename(
+    reference_file = pkg_resources.resource_filename(
         "bob.bio.gmm.test", "data/gmm_projected.hdf5"
     )
-
     if regenerate_refs:
-        projected.save(reference_path)
+        projected.save(reference_file)
 
-    reference = gmm1.read_feature(reference_path)
+    reference = GMMStats.from_hdf5(reference_file)
     assert projected.is_similar_to(reference)
 
 
@@ -122,24 +124,23 @@ def test_enroll():
     reference_file = pkg_resources.resource_filename(
         "bob.bio.gmm.test", "data/gmm_enrolled.hdf5"
     )
-
     if regenerate_refs:
         biometric_reference.save(reference_file)
 
-    gmm2 = GMMMachine.from_hdf5(reference_file, ubm=ubm)
+    gmm2 = gmm1.read_biometric_reference(reference_file)
     assert biometric_reference.is_similar_to(gmm2)
 
 
 def test_score():
     gmm1 = GMM(number_of_gaussians=2)
-    gmm1.ubm = GMMMachine.from_hdf5(
+    gmm1.load_model(
         pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_ubm.hdf5")
     )
     biometric_reference = GMMMachine.from_hdf5(
         pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_enrolled.hdf5"),
         ubm=gmm1.ubm,
     )
-    probe = gmm1.read_feature(
+    probe = GMMStats.from_hdf5(
         pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_projected.hdf5")
     )
 
-- 
GitLab