Skip to content
Snippets Groups Projects
Commit cf90a293 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch 'gmm-params' into 'master'

Adapt to `bob.learn.em` API changes

See merge request !28
parents 56bc342f 363378b5
No related branches found
No related tags found
1 merge request!28Adapt to `bob.learn.em` API changes
Pipeline #58619 passed
......@@ -14,6 +14,7 @@ import copy
import logging
from typing import Callable
from typing import Union
import dask
import dask.array as da
......@@ -23,10 +24,10 @@ from h5py import File as HDF5File
from sklearn.base import BaseEstimator
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgorithm
from bob.learn.em.cluster import KMeansMachine
from bob.learn.em.mixture import GMMMachine
from bob.learn.em.mixture import GMMStats
from bob.learn.em.mixture import linear_scoring
from bob.learn.em import GMMMachine
from bob.learn.em import GMMStats
from bob.learn.em import KMeansMachine
from bob.learn.em import linear_scoring
logger = logging.getLogger(__name__)
......@@ -49,19 +50,23 @@ class GMM(BioAlgorithm, BaseEstimator):
number_of_gaussians: int,
# parameters of UBM training
kmeans_training_iterations: int = 25, # Maximum number of iterations for K-Means
kmeans_init_iterations: Union[
int, None
] = None, # Maximum number of iterations for K-Means init
kmeans_oversampling_factor: int = 64,
ubm_training_iterations: int = 25, # Maximum number of iterations for GMM Training
training_threshold: float = 5e-4, # Threshold to end the ML training
variance_threshold: float = 5e-4, # Minimum value that a variance can reach
update_means: bool = True,
update_variances: bool = True,
update_weights: bool = True,
# parameters of the GMM enrollment
gmm_enroll_iterations: int = 1, # Number of iterations for the enrollment phase
# parameters of the GMM enrollment (MAP)
gmm_enroll_iterations: int = 1,
enroll_update_means: bool = True,
enroll_update_variances: bool = False,
enroll_update_weights: bool = False,
relevance_factor: float = 4, # Relevance factor as described in Reynolds paper
responsibility_threshold: float = 0, # If set, the weight of a particular Gaussian will at least be greater than this threshold. In the case the real weight is lower, the prior mean value will be used to estimate the current mean and variance.
enroll_relevance_factor: Union[float, None] = 4,
enroll_alpha: float = 0.5,
# scoring
scoring_function: Callable = linear_scoring,
# RNG
......@@ -75,6 +80,11 @@ class GMM(BioAlgorithm, BaseEstimator):
The number of Gaussians used in the UBM and the models.
kmeans_training_iterations
Number of e-m iterations to train k-means initializing the UBM.
kmeans_init_iterations
Number of iterations used for setting the k-means initial centroids.
If None, the value of kmeans_training_iterations is used.
kmeans_oversampling_factor
Oversampling factor used by k-means initializer.
ubm_training_iterations
Number of e-m iterations for training the UBM.
training_threshold
......@@ -95,12 +105,11 @@ class GMM(BioAlgorithm, BaseEstimator):
Decides whether the means of the Gaussians are updated while enrolling.
enroll_update_variances
Decides whether the variances of the Gaussians are updated while enrolling.
relevance_factor
Relevance factor as described in Reynolds paper.
responsibility_threshold
If set, the weight of a particular Gaussian will at least be greater than
this threshold. In the case where the real weight is lower, the prior mean
value will be used to estimate the current mean and variance.
enroll_relevance_factor
For enrollment: MAP relevance factor as described in Reynolds paper.
If None, will not apply Reynolds adaptation.
enroll_alpha
For enrollment: MAP adaptation coefficient.
init_seed
Seed for the random number generation.
scoring_function
......@@ -110,20 +119,26 @@ class GMM(BioAlgorithm, BaseEstimator):
# Copy parameters
self.number_of_gaussians = number_of_gaussians
self.kmeans_training_iterations = kmeans_training_iterations
self.kmeans_init_iterations = (
kmeans_training_iterations
if kmeans_init_iterations is None
else kmeans_init_iterations
)
self.kmeans_oversampling_factor = kmeans_oversampling_factor
self.ubm_training_iterations = ubm_training_iterations
self.training_threshold = training_threshold
self.variance_threshold = variance_threshold
self.update_weights = update_weights
self.update_means = update_means
self.update_variances = update_variances
self.relevance_factor = relevance_factor
self.enroll_relevance_factor = enroll_relevance_factor
self.enroll_alpha = enroll_alpha
self.gmm_enroll_iterations = gmm_enroll_iterations
self.enroll_update_means = enroll_update_means
self.enroll_update_weights = enroll_update_weights
self.enroll_update_variances = enroll_update_variances
self.init_seed = init_seed
self.rng = self.init_seed
self.responsibility_threshold = responsibility_threshold
self.scoring_function = scoring_function
......@@ -186,7 +201,6 @@ class GMM(BioAlgorithm, BaseEstimator):
# Use the array to train a GMM and return it
logger.info("Enrolling with %d feature vectors", array.shape[0])
# TODO accept responsibility_threshold in bob.learn.em
with dask.config.set(scheduler="threads"):
gmm = GMMMachine(
n_gaussians=self.number_of_gaussians,
......@@ -199,6 +213,8 @@ class GMM(BioAlgorithm, BaseEstimator):
update_variances=self.enroll_update_variances,
update_weights=self.enroll_update_weights,
mean_var_update_threshold=self.variance_threshold,
relevance_factor=self.enroll_relevance_factor,
alpha=self.enroll_alpha,
)
gmm.fit(array)
return gmm
......@@ -304,7 +320,7 @@ class GMM(BioAlgorithm, BaseEstimator):
convergence_threshold=self.training_threshold,
max_iter=self.kmeans_training_iterations,
init_method="k-means||",
init_max_iter=5,
init_max_iter=self.kmeans_init_iterations,
random_state=self.init_seed,
),
)
......@@ -319,10 +335,21 @@ class GMM(BioAlgorithm, BaseEstimator):
def transform(self, X, **kwargs):
"""Passthrough: return ``X`` unchanged.

Enrollment and scoring each apply their own processing to the data, so no
projection is performed in this step. Any extra keyword arguments are
accepted and ignored.
"""
# The idea would be to apply the projection in Transform (going from extracted
# to GMMStats), but we must not apply this during the training (fit requires
# extracted data directly).
# to GMMStats), but we must not apply this during the training or enrollment
# (those require extracted data directly, not projected).
# `project` is applied in the score function directly.
return X
@classmethod
def custom_enrolled_save_fn(cls, data, path):
"""Save an enrolled model by delegating to ``data.save(path)``."""
# presumably `data` is the GMMMachine produced by enrollment -- TODO confirm
data.save(path)
def custom_enrolled_load_fn(self, path):
"""Load an enrolled model from ``path`` via ``GMMMachine.from_hdf5``.

``self.ubm`` is passed as the ``ubm`` argument of the loader.
"""
return GMMMachine.from_hdf5(path, ubm=self.ubm)
def _more_tags(self):
"""Return the extra tags declared by this estimator as a dict."""
# NOTE(review): two return statements appear here. The first looks like the
# pre-change line kept by the diff view; as written it makes the dict below
# unreachable -- confirm that only the second return is meant to remain.
return {"bob_fit_supports_dask_array": True}
return {
"bob_fit_supports_dask_array": True,
"bob_enrolled_save_fn": self.custom_enrolled_save_fn,
"bob_enrolled_load_fn": self.custom_enrolled_load_fn,
}
No preview for this file type
No preview for this file type
No preview for this file type
......@@ -18,7 +18,6 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging
import os
import tempfile
import numpy
......@@ -28,8 +27,8 @@ import bob.bio.gmm
from bob.bio.base.test import utils
from bob.bio.gmm.algorithm import GMM
from bob.learn.em.mixture.gmm import GMMMachine
from bob.learn.em.mixture.gmm import GMMStats
from bob.learn.em import GMMMachine
from bob.learn.em import GMMStats
logger = logging.getLogger(__name__)
......@@ -97,7 +96,7 @@ def test_projector():
# Generate and project random feature
feature = utils.random_array((20, 45), -5.0, 5.0, seed=seed_value)
projected = gmm1.project(feature)
assert isinstance(projected, bob.learn.em.mixture.GMMStats)
assert isinstance(projected, GMMStats)
reference_file = pkg_resources.resource_filename(
"bob.bio.gmm.test", "data/gmm_projected.hdf5"
......@@ -137,8 +136,8 @@ def test_enroll():
with tempfile.NamedTemporaryFile(prefix="bob_", suffix="_bioref.hdf5") as fd:
temp_file = fd.name
gmm1.write_biometric_reference(biometric_reference, reference_file)
assert os.path.exists(temp_file)
gmm1.write_biometric_reference(biometric_reference, temp_file)
assert GMMMachine.from_hdf5(temp_file, ubm).is_similar_to(gmm2)
def test_score():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment