Skip to content
Snippets Groups Projects
Commit a451243a authored by Yannick DAYER
Browse files

Use load_model and read_biometric_reference

parent 784f98e3
No related branches found
No related tags found
1 merge request: !26 Python implementation of GMM
Pipeline #58362 failed
......@@ -30,8 +30,6 @@ from bob.learn.em.mixture import linear_scoring
logger = logging.getLogger(__name__)
# from bob.pipelines import ToDaskBag # Used when switching from samples to da.Array
class GMM(BioAlgorithm, BaseEstimator):
"""Algorithm for computing UBM and Gaussian Mixture Models of the features.
......@@ -109,7 +107,7 @@ class GMM(BioAlgorithm, BaseEstimator):
Function returning a score from a model, a UBM, and a probe.
"""
# copy parameters
# Copy parameters
self.number_of_gaussians = number_of_gaussians
self.kmeans_training_iterations = kmeans_training_iterations
self.ubm_training_iterations = ubm_training_iterations
......@@ -148,7 +146,7 @@ class GMM(BioAlgorithm, BaseEstimator):
)
def save_model(self, ubm_file):
"""Saves the projector to file."""
"""Saves the projector (UBM) to file."""
# Saves the UBM to file
logger.debug("Saving model to file '%s'", ubm_file)
......@@ -156,44 +154,39 @@ class GMM(BioAlgorithm, BaseEstimator):
self.ubm.save(hdf5)
def load_model(self, ubm_file):
"""Loads the projector from a file."""
"""Loads the projector (UBM) from a file."""
hdf5file = HDF5File(ubm_file, "r")
logger.debug("Loading model from file '%s'", ubm_file)
# Read UBM
# Read the UBM
self.ubm = GMMMachine.from_hdf5(hdf5file)
self.ubm.variance_thresholds = self.variance_threshold
def project(self, array):
"""Computes GMM statistics against a UBM, given a 2D array of feature vectors"""
"""Computes GMM statistics against a UBM, given a 2D array of feature vectors
This is applied to the probes before scoring.
"""
self._check_feature(array)
logger.debug("Projecting %d feature vectors", array.shape[0])
# Accumulates statistics
gmm_stats = self.ubm.transform(array)
gmm_stats.compute()
# return the resulting statistics
# Return the resulting statistics
return gmm_stats
def read_feature(self, feature_file):
"""Read the type of features that we require, namely GMM_Stats"""
return GMMStats.from_hdf5(HDF5File(feature_file, "r"))
def write_feature(self, feature, feature_file):
"""Write the features (GMM_Stats)"""
return feature.save(feature_file)
def enroll(self, data):
"""Enrolls a GMM using MAP adaptation given a reference's feature vectors
Returns a GMMMachine tweaked from the UBM with MAP
Returns a GMMMachine tuned from the UBM with MAP on a biometric reference data.
"""
[self._check_feature(feature) for feature in data]
array = da.vstack(data)
# Use the array to train a GMM and return it
logger.debug(" .... Enrolling with %d feature vectors", array.shape[0])
logger.info("Enrolling with %d feature vectors", array.shape[0])
# TODO responsibility_threshold
# TODO accept responsibility_threshold in bob.learn.em
with dask.config.set(scheduler="threads"):
gmm = GMMMachine(
n_gaussians=self.number_of_gaussians,
......@@ -205,18 +198,21 @@ class GMM(BioAlgorithm, BaseEstimator):
update_means=self.enroll_update_means,
update_variances=self.enroll_update_variances,
update_weights=self.enroll_update_weights,
mean_var_update_threshold=self.variance_threshold,
)
gmm.variance_thresholds = self.variance_threshold
gmm.fit(array)
return gmm
def read_biometric_reference(self, model_file):
"""Reads an enrolled reference model, which is a MAP GMMMachine"""
"""Reads an enrolled reference model, which is a MAP GMMMachine."""
if self.ubm is None:
raise ValueError(
"You must load a UBM before reading a biometric reference."
)
return GMMMachine.from_hdf5(HDF5File(model_file, "r"), ubm=self.ubm)
@classmethod
def write_biometric_reference(cls, model: GMMMachine, model_file):
"""Write the enrolled reference (MAP GMMMachine)"""
def write_biometric_reference(self, model: GMMMachine, model_file):
"""Write the enrolled reference (MAP GMMMachine) into a file."""
return model.save(model_file)
def score(self, biometric_reference: GMMMachine, probe):
......@@ -307,6 +303,7 @@ class GMM(BioAlgorithm, BaseEstimator):
update_means=self.update_means,
update_variances=self.update_variances,
update_weights=self.update_weights,
mean_var_update_threshold=self.variance_threshold,
k_means_trainer=KMeansMachine(
self.number_of_gaussians,
convergence_threshold=self.training_threshold,
......
No preview for this file type
No preview for this file type
No preview for this file type
......@@ -18,7 +18,6 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging
import os
import tempfile
import pkg_resources
......@@ -28,6 +27,7 @@ import bob.bio.gmm
from bob.bio.base.test import utils
from bob.bio.gmm.algorithm import GMM
from bob.learn.em.mixture import GMMMachine
from bob.learn.em.mixture import GMMStats
logger = logging.getLogger(__name__)
......@@ -50,6 +50,7 @@ def test_class():
def test_training():
"""Tests the generation of the UBM."""
# Set a small training iteration count
gmm1 = GMM(
number_of_gaussians=2,
kmeans_training_iterations=1,
......@@ -59,24 +60,26 @@ def test_training():
train_data = utils.random_training_set(
(100, 45), count=5, minimum=-5.0, maximum=5.0
)
reference_file = pkg_resources.resource_filename(
"bob.bio.gmm.test", "data/gmm_ubm.hdf5"
)
# Train the projector
# Train the UBM (projector)
gmm1.fit(train_data)
# Test saving and loading of projector
with tempfile.NamedTemporaryFile(prefix="bob_", suffix="_model.hdf5") as fd:
temp_file = fd.name
gmm1.save_model(temp_file)
assert os.path.exists(temp_file)
reference_file = pkg_resources.resource_filename(
"bob.bio.gmm.test", "data/gmm_ubm.hdf5"
)
if regenerate_refs:
gmm1.save_model(reference_file)
gmm1.ubm = GMMMachine.from_hdf5(reference_file)
assert gmm1.ubm.is_similar_to(GMMMachine.from_hdf5(temp_file))
gmm2 = GMM(number_of_gaussians=2)
gmm2.load_model(temp_file)
ubm_reference = GMMMachine.from_hdf5(reference_file)
assert gmm2.ubm.is_similar_to(ubm_reference)
def test_projector():
......@@ -92,14 +95,13 @@ def test_projector():
projected = gmm1.project(feature)
assert isinstance(projected, bob.learn.em.mixture.GMMStats)
reference_path = pkg_resources.resource_filename(
reference_file = pkg_resources.resource_filename(
"bob.bio.gmm.test", "data/gmm_projected.hdf5"
)
if regenerate_refs:
projected.save(reference_path)
projected.save(reference_file)
reference = gmm1.read_feature(reference_path)
reference = GMMStats.from_hdf5(reference_file)
assert projected.is_similar_to(reference)
......@@ -122,24 +124,23 @@ def test_enroll():
reference_file = pkg_resources.resource_filename(
"bob.bio.gmm.test", "data/gmm_enrolled.hdf5"
)
if regenerate_refs:
biometric_reference.save(reference_file)
gmm2 = GMMMachine.from_hdf5(reference_file, ubm=ubm)
gmm2 = gmm1.read_biometric_reference(reference_file)
assert biometric_reference.is_similar_to(gmm2)
def test_score():
gmm1 = GMM(number_of_gaussians=2)
gmm1.ubm = GMMMachine.from_hdf5(
gmm1.load_model(
pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_ubm.hdf5")
)
biometric_reference = GMMMachine.from_hdf5(
pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_enrolled.hdf5"),
ubm=gmm1.ubm,
)
probe = gmm1.read_feature(
probe = GMMStats.from_hdf5(
pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_projected.hdf5")
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment