Skip to content
Snippets Groups Projects
Commit a451243a authored by Yannick DAYER's avatar Yannick DAYER
Browse files

Use load_model and read_biometric_reference

parent 784f98e3
Branches
No related tags found
1 merge request!26Python implementation of GMM
Pipeline #58362 failed
...@@ -30,8 +30,6 @@ from bob.learn.em.mixture import linear_scoring ...@@ -30,8 +30,6 @@ from bob.learn.em.mixture import linear_scoring
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# from bob.pipelines import ToDaskBag # Used when switching from samples to da.Array
class GMM(BioAlgorithm, BaseEstimator): class GMM(BioAlgorithm, BaseEstimator):
"""Algorithm for computing UBM and Gaussian Mixture Models of the features. """Algorithm for computing UBM and Gaussian Mixture Models of the features.
...@@ -109,7 +107,7 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -109,7 +107,7 @@ class GMM(BioAlgorithm, BaseEstimator):
Function returning a score from a model, a UBM, and a probe. Function returning a score from a model, a UBM, and a probe.
""" """
# copy parameters # Copy parameters
self.number_of_gaussians = number_of_gaussians self.number_of_gaussians = number_of_gaussians
self.kmeans_training_iterations = kmeans_training_iterations self.kmeans_training_iterations = kmeans_training_iterations
self.ubm_training_iterations = ubm_training_iterations self.ubm_training_iterations = ubm_training_iterations
...@@ -148,7 +146,7 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -148,7 +146,7 @@ class GMM(BioAlgorithm, BaseEstimator):
) )
def save_model(self, ubm_file): def save_model(self, ubm_file):
"""Saves the projector to file.""" """Saves the projector (UBM) to file."""
# Saves the UBM to file # Saves the UBM to file
logger.debug("Saving model to file '%s'", ubm_file) logger.debug("Saving model to file '%s'", ubm_file)
...@@ -156,44 +154,39 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -156,44 +154,39 @@ class GMM(BioAlgorithm, BaseEstimator):
self.ubm.save(hdf5) self.ubm.save(hdf5)
def load_model(self, ubm_file): def load_model(self, ubm_file):
"""Loads the projector from a file.""" """Loads the projector (UBM) from a file."""
hdf5file = HDF5File(ubm_file, "r") hdf5file = HDF5File(ubm_file, "r")
logger.debug("Loading model from file '%s'", ubm_file) logger.debug("Loading model from file '%s'", ubm_file)
# Read UBM # Read the UBM
self.ubm = GMMMachine.from_hdf5(hdf5file) self.ubm = GMMMachine.from_hdf5(hdf5file)
self.ubm.variance_thresholds = self.variance_threshold self.ubm.variance_thresholds = self.variance_threshold
def project(self, array): def project(self, array):
"""Computes GMM statistics against a UBM, given a 2D array of feature vectors""" """Computes GMM statistics against a UBM, given a 2D array of feature vectors
This is applied to the probes before scoring.
"""
self._check_feature(array) self._check_feature(array)
logger.debug("Projecting %d feature vectors", array.shape[0]) logger.debug("Projecting %d feature vectors", array.shape[0])
# Accumulates statistics # Accumulates statistics
gmm_stats = self.ubm.transform(array) gmm_stats = self.ubm.transform(array)
gmm_stats.compute() gmm_stats.compute()
# return the resulting statistics # Return the resulting statistics
return gmm_stats return gmm_stats
def read_feature(self, feature_file):
"""Read the type of features that we require, namely GMM_Stats"""
return GMMStats.from_hdf5(HDF5File(feature_file, "r"))
def write_feature(self, feature, feature_file):
"""Write the features (GMM_Stats)"""
return feature.save(feature_file)
def enroll(self, data): def enroll(self, data):
"""Enrolls a GMM using MAP adaptation given a reference's feature vectors """Enrolls a GMM using MAP adaptation given a reference's feature vectors
Returns a GMMMachine tweaked from the UBM with MAP Returns a GMMMachine tuned from the UBM with MAP on a biometric reference data.
""" """
[self._check_feature(feature) for feature in data] [self._check_feature(feature) for feature in data]
array = da.vstack(data) array = da.vstack(data)
# Use the array to train a GMM and return it # Use the array to train a GMM and return it
logger.debug(" .... Enrolling with %d feature vectors", array.shape[0]) logger.info("Enrolling with %d feature vectors", array.shape[0])
# TODO responsibility_threshold # TODO accept responsibility_threshold in bob.learn.em
with dask.config.set(scheduler="threads"): with dask.config.set(scheduler="threads"):
gmm = GMMMachine( gmm = GMMMachine(
n_gaussians=self.number_of_gaussians, n_gaussians=self.number_of_gaussians,
...@@ -205,18 +198,21 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -205,18 +198,21 @@ class GMM(BioAlgorithm, BaseEstimator):
update_means=self.enroll_update_means, update_means=self.enroll_update_means,
update_variances=self.enroll_update_variances, update_variances=self.enroll_update_variances,
update_weights=self.enroll_update_weights, update_weights=self.enroll_update_weights,
mean_var_update_threshold=self.variance_threshold,
) )
gmm.variance_thresholds = self.variance_threshold
gmm.fit(array) gmm.fit(array)
return gmm return gmm
def read_biometric_reference(self, model_file): def read_biometric_reference(self, model_file):
"""Reads an enrolled reference model, which is a MAP GMMMachine""" """Reads an enrolled reference model, which is a MAP GMMMachine."""
if self.ubm is None:
raise ValueError(
"You must load a UBM before reading a biometric reference."
)
return GMMMachine.from_hdf5(HDF5File(model_file, "r"), ubm=self.ubm) return GMMMachine.from_hdf5(HDF5File(model_file, "r"), ubm=self.ubm)
@classmethod def write_biometric_reference(self, model: GMMMachine, model_file):
def write_biometric_reference(cls, model: GMMMachine, model_file): """Write the enrolled reference (MAP GMMMachine) into a file."""
"""Write the enrolled reference (MAP GMMMachine)"""
return model.save(model_file) return model.save(model_file)
def score(self, biometric_reference: GMMMachine, probe): def score(self, biometric_reference: GMMMachine, probe):
...@@ -307,6 +303,7 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -307,6 +303,7 @@ class GMM(BioAlgorithm, BaseEstimator):
update_means=self.update_means, update_means=self.update_means,
update_variances=self.update_variances, update_variances=self.update_variances,
update_weights=self.update_weights, update_weights=self.update_weights,
mean_var_update_threshold=self.variance_threshold,
k_means_trainer=KMeansMachine( k_means_trainer=KMeansMachine(
self.number_of_gaussians, self.number_of_gaussians,
convergence_threshold=self.training_threshold, convergence_threshold=self.training_threshold,
......
No preview for this file type
No preview for this file type
No preview for this file type
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging import logging
import os
import tempfile import tempfile
import pkg_resources import pkg_resources
...@@ -28,6 +27,7 @@ import bob.bio.gmm ...@@ -28,6 +27,7 @@ import bob.bio.gmm
from bob.bio.base.test import utils from bob.bio.base.test import utils
from bob.bio.gmm.algorithm import GMM from bob.bio.gmm.algorithm import GMM
from bob.learn.em.mixture import GMMMachine from bob.learn.em.mixture import GMMMachine
from bob.learn.em.mixture import GMMStats
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -50,6 +50,7 @@ def test_class(): ...@@ -50,6 +50,7 @@ def test_class():
def test_training(): def test_training():
"""Tests the generation of the UBM.""" """Tests the generation of the UBM."""
# Set a small training iteration count
gmm1 = GMM( gmm1 = GMM(
number_of_gaussians=2, number_of_gaussians=2,
kmeans_training_iterations=1, kmeans_training_iterations=1,
...@@ -59,24 +60,26 @@ def test_training(): ...@@ -59,24 +60,26 @@ def test_training():
train_data = utils.random_training_set( train_data = utils.random_training_set(
(100, 45), count=5, minimum=-5.0, maximum=5.0 (100, 45), count=5, minimum=-5.0, maximum=5.0
) )
reference_file = pkg_resources.resource_filename(
"bob.bio.gmm.test", "data/gmm_ubm.hdf5"
)
# Train the projector # Train the UBM (projector)
gmm1.fit(train_data) gmm1.fit(train_data)
# Test saving and loading of projector
with tempfile.NamedTemporaryFile(prefix="bob_", suffix="_model.hdf5") as fd: with tempfile.NamedTemporaryFile(prefix="bob_", suffix="_model.hdf5") as fd:
temp_file = fd.name temp_file = fd.name
gmm1.save_model(temp_file) gmm1.save_model(temp_file)
assert os.path.exists(temp_file) reference_file = pkg_resources.resource_filename(
"bob.bio.gmm.test", "data/gmm_ubm.hdf5"
)
if regenerate_refs: if regenerate_refs:
gmm1.save_model(reference_file) gmm1.save_model(reference_file)
gmm1.ubm = GMMMachine.from_hdf5(reference_file) gmm2 = GMM(number_of_gaussians=2)
assert gmm1.ubm.is_similar_to(GMMMachine.from_hdf5(temp_file))
gmm2.load_model(temp_file)
ubm_reference = GMMMachine.from_hdf5(reference_file)
assert gmm2.ubm.is_similar_to(ubm_reference)
def test_projector(): def test_projector():
...@@ -92,14 +95,13 @@ def test_projector(): ...@@ -92,14 +95,13 @@ def test_projector():
projected = gmm1.project(feature) projected = gmm1.project(feature)
assert isinstance(projected, bob.learn.em.mixture.GMMStats) assert isinstance(projected, bob.learn.em.mixture.GMMStats)
reference_path = pkg_resources.resource_filename( reference_file = pkg_resources.resource_filename(
"bob.bio.gmm.test", "data/gmm_projected.hdf5" "bob.bio.gmm.test", "data/gmm_projected.hdf5"
) )
if regenerate_refs: if regenerate_refs:
projected.save(reference_path) projected.save(reference_file)
reference = gmm1.read_feature(reference_path) reference = GMMStats.from_hdf5(reference_file)
assert projected.is_similar_to(reference) assert projected.is_similar_to(reference)
...@@ -122,24 +124,23 @@ def test_enroll(): ...@@ -122,24 +124,23 @@ def test_enroll():
reference_file = pkg_resources.resource_filename( reference_file = pkg_resources.resource_filename(
"bob.bio.gmm.test", "data/gmm_enrolled.hdf5" "bob.bio.gmm.test", "data/gmm_enrolled.hdf5"
) )
if regenerate_refs: if regenerate_refs:
biometric_reference.save(reference_file) biometric_reference.save(reference_file)
gmm2 = GMMMachine.from_hdf5(reference_file, ubm=ubm) gmm2 = gmm1.read_biometric_reference(reference_file)
assert biometric_reference.is_similar_to(gmm2) assert biometric_reference.is_similar_to(gmm2)
def test_score(): def test_score():
gmm1 = GMM(number_of_gaussians=2) gmm1 = GMM(number_of_gaussians=2)
gmm1.ubm = GMMMachine.from_hdf5( gmm1.load_model(
pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_ubm.hdf5") pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_ubm.hdf5")
) )
biometric_reference = GMMMachine.from_hdf5( biometric_reference = GMMMachine.from_hdf5(
pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_enrolled.hdf5"), pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_enrolled.hdf5"),
ubm=gmm1.ubm, ubm=gmm1.ubm,
) )
probe = gmm1.read_feature( probe = GMMStats.from_hdf5(
pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_projected.hdf5") pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_projected.hdf5")
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment