From 248cf61a8a39c5785e6d899c1d23a88fc9c6b33c Mon Sep 17 00:00:00 2001 From: Yannick DAYER <yannick.dayer@idiap.ch> Date: Mon, 13 Dec 2021 11:09:16 +0100 Subject: [PATCH] Remove old algorithms --- bob/bio/gmm/__init__.py | 1 - bob/bio/gmm/algorithm/GMM.py | 486 ++++--- bob/bio/gmm/algorithm/ISV.py | 232 ---- bob/bio/gmm/algorithm/IVector.py | 412 ------ bob/bio/gmm/algorithm/JFA.py | 131 -- bob/bio/gmm/algorithm/__init__.py | 12 +- bob/bio/gmm/bioalgorithm/GMM.py | 322 ----- bob/bio/gmm/bioalgorithm/__init__.py | 23 - bob/bio/gmm/config/algorithm/gmm.py | 4 +- bob/bio/gmm/config/bioalgorithm/__init__.py | 0 bob/bio/gmm/config/bioalgorithm/gmm.py | 3 - .../gmm/config/bioalgorithm/gmm_regular.py | 5 - bob/bio/gmm/test/data/gmm_model.hdf5 | Bin 9984 -> 12920 bytes bob/bio/gmm/test/data/gmm_projected.hdf5 | Bin 6232 -> 10608 bytes bob/bio/gmm/test/data/gmm_projector.hdf5 | Bin 9984 -> 12920 bytes bob/bio/gmm/test/test_algorithms.py | 1161 +++++++++-------- setup.py | 4 - 17 files changed, 821 insertions(+), 1975 deletions(-) delete mode 100644 bob/bio/gmm/algorithm/ISV.py delete mode 100644 bob/bio/gmm/algorithm/IVector.py delete mode 100644 bob/bio/gmm/algorithm/JFA.py delete mode 100644 bob/bio/gmm/bioalgorithm/GMM.py delete mode 100644 bob/bio/gmm/bioalgorithm/__init__.py delete mode 100644 bob/bio/gmm/config/bioalgorithm/__init__.py delete mode 100644 bob/bio/gmm/config/bioalgorithm/gmm.py delete mode 100644 bob/bio/gmm/config/bioalgorithm/gmm_regular.py diff --git a/bob/bio/gmm/__init__.py b/bob/bio/gmm/__init__.py index 24f76d9..c020e48 100644 --- a/bob/bio/gmm/__init__.py +++ b/bob/bio/gmm/__init__.py @@ -1,5 +1,4 @@ from . import algorithm # noqa: F401 -from . import bioalgorithm # noqa: F401 from . 
import test # noqa: F401 diff --git a/bob/bio/gmm/algorithm/GMM.py b/bob/bio/gmm/algorithm/GMM.py index 747d45b..ed574e2 100644 --- a/bob/bio/gmm/algorithm/GMM.py +++ b/bob/bio/gmm/algorithm/GMM.py @@ -2,75 +2,110 @@ # vim: set fileencoding=utf-8 : # Manuel Guenther <Manuel.Guenther@idiap.ch> +"""Interface between the lower level GMM classes and the Algorithm Transformer. + +Implements the enroll and score methods using the low level GMM implementation. + +This adds the notions of models, probes, enrollment, and scores to GMM. +""" + import logging -from multiprocessing.pool import ThreadPool +from typing import Callable + +import dask.array as da +import numpy as np +import dask +from h5py import File as HDF5File + +from sklearn.base import BaseEstimator -import numpy +from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgorithm +from bob.learn.em.mixture import GMMMachine +from bob.learn.em.mixture import GMMStats +from bob.learn.em.mixture import linear_scoring +from bob.pipelines.wrappers import DaskWrapper -import bob.core -import bob.io.base -import bob.learn.em +logger = logging.getLogger(__name__) -from bob.bio.base.algorithm import Algorithm +# from bob.pipelines import ToDaskBag # Used when switching from samples to da.Array -logger = logging.getLogger("bob.bio.gmm") +class GMM(BioAlgorithm, BaseEstimator): + """Algorithm for computing UBM and Gaussian Mixture Models of the features. -class GMM(Algorithm): - """Algorithm for computing Universal Background Models and Gaussian Mixture Models of the features. - Features must be normalized to zero mean and unit standard deviation.""" + Features must be normalized to zero mean and unit standard deviation. + + Models are MAP GMM machines trained from a UBM on the enrollment feature set. + + The UBM is a ML GMM machine trained on the training feature set. + + Probes are GMM statistics of features projected on the UBM. 
+ """ def __init__( self, # parameters for the GMM - number_of_gaussians, + number_of_gaussians: int, # parameters of UBM training - kmeans_training_iterations=25, # Maximum number of iterations for K-Means - gmm_training_iterations=25, # Maximum number of iterations for ML GMM Training - training_threshold=5e-4, # Threshold to end the ML training - variance_threshold=5e-4, # Minimum value that a variance can reach - update_weights=True, - update_means=True, - update_variances=True, + kmeans_training_iterations: int = 25, # Maximum number of iterations for K-Means + ubm_training_iterations: int = 25, # Maximum number of iterations for GMM Training + training_threshold: float = 5e-4, # Threshold to end the ML training + variance_threshold: float = 5e-4, # Minimum value that a variance can reach + update_weights: bool = True, + update_means: bool = True, + update_variances: bool = True, # parameters of the GMM enrollment - relevance_factor=4, # Relevance factor as described in Reynolds paper - gmm_enroll_iterations=1, # Number of iterations for the enrollment phase - responsibility_threshold=0, # If set, the weight of a particular Gaussian will at least be greater than this threshold. In the case the real weight is lower, the prior mean value will be used to estimate the current mean and variance. - INIT_SEED=5489, + relevance_factor: float = 4, # Relevance factor as described in Reynolds paper + gmm_enroll_iterations: int = 1, # Number of iterations for the enrollment phase + responsibility_threshold: float = 0, # If set, the weight of a particular Gaussian will at least be greater than this threshold. In the case the real weight is lower, the prior mean value will be used to estimate the current mean and variance. 
+ init_seed: int = 5489, # scoring - scoring_function=bob.learn.em.linear_scoring, - n_threads=None, + scoring_function: Callable = linear_scoring, + # n_threads=None, ): - """Initializes the local UBM-GMM tool chain with the given file selector object""" + """Initializes the local UBM-GMM tool chain. + + Parameters + ---------- + number_of_gaussians + The number of Gaussians used in the UBM and the models. + kmeans_training_iterations + Number of e-m iterations to train k-means initializing the UBM. + ubm_training_iterations + Number of e-m iterations for training the UBM. + training_threshold + Convergence threshold to halt the GMM training early. + variance_threshold + Minimum value a variance of the Gaussians can reach. + update_weights + Decides whether the weights of the Gaussians are updated while training. + update_means + Decides whether the means of the Gaussians are updated while training. + update_variances + Decides whether the variances of the Gaussians are updated while training. + relevance_factor + Relevance factor as described in Reynolds paper. + gmm_enroll_iterations + Number of iterations for the MAP GMM used for enrollment. + responsibility_threshold + If set, the weight of a particular Gaussian will at least be greater than + this threshold. In the case where the real weight is lower, the prior mean + value will be used to estimate the current mean and variance. + init_seed + Seed for the random number generation. + scoring_function + Function returning a score from a model, a UBM, and a probe. 
+ """ # call base class constructor and register that this tool performs projection - Algorithm.__init__( - self, - performs_projection=True, - use_projected_features_for_enrollment=False, - number_of_gaussians=number_of_gaussians, - kmeans_training_iterations=kmeans_training_iterations, - gmm_training_iterations=gmm_training_iterations, - training_threshold=training_threshold, - variance_threshold=variance_threshold, - update_weights=update_weights, - update_means=update_means, - update_variances=update_variances, - relevance_factor=relevance_factor, - gmm_enroll_iterations=gmm_enroll_iterations, - responsibility_threshold=responsibility_threshold, - INIT_SEED=INIT_SEED, - scoring_function=str(scoring_function), - multiple_model_scoring=None, - multiple_probe_scoring="average", - ) + # super().__init__(score_reduction_operation=??) # copy parameters - self.gaussians = number_of_gaussians + self.number_of_gaussians = number_of_gaussians self.kmeans_training_iterations = kmeans_training_iterations - self.gmm_training_iterations = gmm_training_iterations + self.ubm_training_iterations = ubm_training_iterations self.training_threshold = training_threshold self.variance_threshold = variance_threshold self.update_weights = update_weights @@ -78,261 +113,210 @@ class GMM(Algorithm): self.update_variances = update_variances self.relevance_factor = relevance_factor self.gmm_enroll_iterations = gmm_enroll_iterations - self.init_seed = INIT_SEED - self.rng = bob.core.random.mt19937(self.init_seed) + self.init_seed = init_seed + self.rng = self.init_seed # TODO verify if rng object needed self.responsibility_threshold = responsibility_threshold - self.scoring_function = scoring_function - self.n_threads = n_threads - self.pool = None + + def scoring_function_wrapped(*args, **kwargs): + with dask.config.set(scheduler="threads"): + return scoring_function(*args, **kwargs) + + self.scoring_function = scoring_function_wrapped self.ubm = None - self.kmeans_trainer = 
bob.learn.em.KMeansTrainer() - self.ubm_trainer = bob.learn.em.ML_GMMTrainer( - self.update_means, - self.update_variances, - self.update_weights, - self.responsibility_threshold, - ) + + super().__init__() def _check_feature(self, feature): """Checks that the features are appropriate""" if ( - not isinstance(feature, numpy.ndarray) + not isinstance(feature, np.ndarray) or feature.ndim != 2 - or feature.dtype != numpy.float64 + or feature.dtype != np.float64 ): - raise ValueError("The given feature is not appropriate") + raise ValueError(f"The given feature is not appropriate: \n{feature}") if self.ubm is not None and feature.shape[1] != self.ubm.shape[1]: raise ValueError( "The given feature is expected to have %d elements, but it has %d" % (self.ubm.shape[1], feature.shape[1]) ) - ####################################################### - # UBM training # - - def train_ubm(self, array): - - logger.debug(" .... Training with %d feature vectors", array.shape[0]) - if self.n_threads is not None: - self.pool = ThreadPool(self.n_threads) - - # Computes input size - input_size = array.shape[1] - - # Creates the machines (KMeans and GMM) - logger.debug(" .... Creating machines") - kmeans = bob.learn.em.KMeansMachine(self.gaussians, input_size) - self.ubm = bob.learn.em.GMMMachine(self.gaussians, input_size) - - # Trains using the KMeansTrainer - logger.info(" -> Training K-Means") - - # Reseting the pseudo random number generator so we can have the same initialization for serial and parallel execution. 
- self.rng = bob.core.random.mt19937(self.init_seed) - bob.learn.em.train( - self.kmeans_trainer, - kmeans, - array, - self.kmeans_training_iterations, - self.training_threshold, - rng=self.rng, - pool=self.pool, - ) - - variances, weights = kmeans.get_variances_and_weights_for_each_cluster(array) - means = kmeans.means - - # Initializes the GMM - self.ubm.means = means - self.ubm.variances = variances - self.ubm.weights = weights - self.ubm.set_variance_thresholds(self.variance_threshold) - - # Trains the GMM - logger.info(" -> Training GMM") - # Reseting the pseudo random number generator so we can have the same initialization for serial and parallel execution. - self.rng = bob.core.random.mt19937(self.init_seed) - bob.learn.em.train( - self.ubm_trainer, - self.ubm, - array, - self.gmm_training_iterations, - self.training_threshold, - rng=self.rng, - pool=self.pool, - ) - - def save_ubm(self, projector_file): - """Save projector to file""" + def save_ubm(self, ubm_file): + """Saves the projector to file""" # Saves the UBM to file - logger.debug(" .... 
Saving model to file '%s'", projector_file) + logger.debug("Saving model to file '%s'", ubm_file) + hdf5 = ( - projector_file - if isinstance(projector_file, bob.io.base.HDF5File) - else bob.io.base.HDF5File(projector_file, "w") + ubm_file + if isinstance(ubm_file, HDF5File) + else HDF5File(ubm_file, "w") ) self.ubm.save(hdf5) - def train_projector(self, train_features, projector_file): - """Computes the Universal Background Model from the training ("world") data""" - [self._check_feature(feature) for feature in train_features] - - logger.info( - " -> Training UBM model with %d training files", len(train_features) - ) - - # Loads the data into an array - array = numpy.vstack(train_features) - - self.train_ubm(array) - - self.save_ubm(projector_file) - - ####################################################### - # GMM training using UBM # - def load_ubm(self, ubm_file): - hdf5file = bob.io.base.HDF5File(ubm_file) + hdf5file = HDF5File(ubm_file) + logger.debug("Loading model from file '%s'", ubm_file) # read UBM - self.ubm = bob.learn.em.GMMMachine(hdf5file) - self.ubm.set_variance_thresholds(self.variance_threshold) + self.ubm = GMMMachine.from_hdf5(hdf5file) + self.ubm.variance_thresholds = self.variance_threshold - def load_projector(self, projector_file): - """Reads the UBM model from file""" - # read UBM - self.load_ubm(projector_file) - # prepare MAP_GMM_Trainer - kwargs = ( - dict( - mean_var_update_responsibilities_threshold=self.responsibility_threshold - ) - if self.responsibility_threshold > 0.0 - else dict() - ) - self.enroll_trainer = bob.learn.em.MAP_GMMTrainer( - self.ubm, - relevance_factor=self.relevance_factor, - update_means=True, - update_variances=False, - **kwargs - ) - self.rng = bob.core.random.mt19937(self.init_seed) - - def project_ubm(self, array): - logger.debug(" .... 
Projecting %d feature vectors" % array.shape[0]) + def project(self, array): + """Computes GMM statistics against a UBM, given a 2D array of feature vectors""" + self._check_feature(array) + logger.debug(" .... Projecting %d feature vectors", array.shape[0]) # Accumulates statistics - gmm_stats = bob.learn.em.GMMStats(self.ubm.shape[0], self.ubm.shape[1]) - self.ubm.acc_statistics(array, gmm_stats) + with dask.config.set(scheduler="threads"): + gmm_stats = GMMStats(self.ubm.shape[0], self.ubm.shape[1]) + self.ubm.acc_statistics(array, gmm_stats) + gmm_stats.compute() # return the resulting statistics return gmm_stats - def project(self, feature): - """Computes GMM statistics against a UBM, given an input 2D numpy.ndarray of feature vectors""" - self._check_feature(feature) - return self.project_ubm(feature) - - def read_gmm_stats(self, gmm_stats_file): - """Reads GMM stats from file.""" - return bob.learn.em.GMMStats(bob.io.base.HDF5File(gmm_stats_file)) - def read_feature(self, feature_file): """Read the type of features that we require, namely GMM_Stats""" - return self.read_gmm_stats(feature_file) + return GMMStats.from_hdf5(HDF5File(feature_file)) + + def write_feature(self, feature, feature_file): + """Write the features (GMM_Stats)""" + return feature.save(feature_file) - def enroll_gmm(self, array): + def enroll(self, data): + """Enrolls a GMM using MAP adaptation, given a list of 2D np.ndarray's of feature vectors""" + [self._check_feature(feature) for feature in data] + array = np.vstack(data) + # Use the array to train a GMM and return it logger.debug(" .... 
Enrolling with %d feature vectors", array.shape[0]) - gmm = bob.learn.em.GMMMachine(self.ubm) - gmm.set_variance_thresholds(self.variance_threshold) - bob.learn.em.train( - self.enroll_trainer, - gmm, - array, - self.gmm_enroll_iterations, - self.training_threshold, - rng=self.rng, - pool=self.pool, - ) + # TODO responsibility_threshold + with dask.config.set(scheduler="threads"): + gmm = GMMMachine( + n_gaussians=self.number_of_gaussians, + trainer="map", + ubm=self.ubm, + convergence_threshold=self.training_threshold, + max_fitting_steps=self.gmm_enroll_iterations, + random_state=self.rng, + update_means=True, + update_variances=True, # TODO default? + update_weights=True, # TODO default? + ) + gmm.variance_thresholds = self.variance_threshold + gmm = gmm.fit(array) + # info = {k: type(v) for k, v in gmm.__dict__.items()} + # for k, v in gmm.gaussians_.__dict__.items(): + # info[k] = type(v) + # raise ValueError(str(info)) return gmm - def enroll(self, feature_arrays): - """Enrolls a GMM using MAP adaptation, given a list of 2D numpy.ndarray's of feature vectors""" - [self._check_feature(feature) for feature in feature_arrays] - array = numpy.vstack(feature_arrays) - # Use the array to train a GMM and return it - return self.enroll_gmm(array) - - ###################################################### - # Feature comparison # def read_model(self, model_file): """Reads the model, which is a GMM machine""" - return bob.learn.em.GMMMachine(bob.io.base.HDF5File(model_file)) + return GMMMachine.from_hdf5(HDF5File(model_file), ubm=self.ubm) + + def write_model(self, model, model_file): + """Writes the model, which is a GMM machine""" + return model.save(model_file) + + def score(self, biometric_reference: GMMMachine, data: GMMStats): + """Computes the score for the given model and the given probe. 
- def score(self, model, probe): - """Computes the score for the given model and the given probe using the scoring function from the config file""" - assert isinstance(model, bob.learn.em.GMMMachine) - assert isinstance(probe, bob.learn.em.GMMStats) + Uses the scoring function passed during initialization. + + Parameters + ---------- + biometric_reference: + The model to score against. + data: + The probe data to compare to the model. + """ + + assert isinstance(biometric_reference, GMMMachine) + return self.scoring_function( + models_means=[biometric_reference], + ubm=self.ubm, + test_stats=data, + frame_length_normalization=True, + )[0, 0] + + def score_multiple_biometric_references( + self, biometric_references: "list[GMMMachine]", data: GMMStats + ): + """Computes the score between multiple models and one probe. + + Uses the scoring function passed during initialization. + + Parameters + ---------- + biometric_references: + The models to score against. + data: + The probe data to compare to the models. 
+ """ + + assert isinstance(biometric_references[0], GMMMachine), type( + biometric_references[0] + ) + stats = self.project(data) return self.scoring_function( - [model], self.ubm, [probe], [], frame_length_normalisation=True - )[0][0] + models_means=biometric_references, + ubm=self.ubm, + test_stats=stats, + frame_length_normalization=True, + ) def score_for_multiple_probes(self, model, probes): """This function computes the score between the given model and several given probe files.""" - assert isinstance(model, bob.learn.em.GMMMachine) + assert isinstance(model, GMMMachine) for probe in probes: - assert isinstance(probe, bob.learn.em.GMMStats) + assert isinstance(probe, GMMStats) # logger.warn("Please verify that this function is correct") - return self.probe_fusion_function( + return ( self.scoring_function( - [model], self.ubm, probes, [], frame_length_normalisation=True + models_means=model.means, + ubm=self.ubm, + test_stats=probes, + frame_length_normalization=True, ) + .mean() + .reshape((-1,)) ) + def fit(self, X, y=None, **kwargs): + """Trains the UBM.""" -class GMMRegular(GMM): - """Algorithm for computing Universal Background Models and Gaussian Mixture Models of the features""" - - def __init__(self, **kwargs): - """Initializes the local UBM-GMM tool chain with the given file selector object""" - # logger.warn("This class must be checked. Please verify that I didn't do any mistake here. 
I had to rename 'train_projector' into a 'train_enroller'!") - # initialize the UBMGMM base class - GMM.__init__(self, **kwargs) - # register a different set of functions in the Tool base class - Algorithm.__init__( - self, requires_enroller_training=True, performs_projection=False - ) - - ####################################################### - # UBM training # - - def train_enroller(self, train_features, enroller_file): - """Computes the Universal Background Model from the training ("world") data""" - train_features = [feature for client in train_features for feature in client] - return self.train_projector(train_features, enroller_file) + # Stack all the samples in a 2D array of features + array = da.vstack(X) - ####################################################### - # GMM training using UBM # + logger.debug("UBM with %d feature vectors", array.shape[0]) - def load_enroller(self, enroller_file): - """Reads the UBM model from file""" - return self.load_projector(enroller_file) + logger.debug(f"Creating UBM machine with {self.number_of_gaussians} gaussians") - ###################################################### - # Feature comparison # - def score(self, model, probe): - """Computes the score for the given model and the given probe. - The score are Log-Likelihood. - Therefore, the log of the likelihood ratio is obtained by computing the following difference.""" - - assert isinstance(model, bob.learn.em.GMMMachine) - self._check_feature(probe) - score = sum( - model.log_likelihood(probe[i, :]) - self.ubm.log_likelihood(probe[i, :]) - for i in range(probe.shape[0]) + self.ubm = GMMMachine( + n_gaussians=self.number_of_gaussians, + trainer="ml", + max_fitting_steps=self.ubm_training_iterations, + convergence_threshold=self.training_threshold, + update_means=self.update_means, + update_variances=self.update_variances, + update_weights=self.update_weights, + # TODO more params? 
) - return score / probe.shape[0] - def score_for_multiple_probes(self, model, probes): - raise NotImplementedError("Implement Me!") + # Trains the GMM + logger.info("Training UBM GMM") + # Resetting the pseudo random number generator so we can have the same initialization for serial and parallel execution. + # self.rng = bob.core.random.mt19937(self.init_seed) + self.ubm = self.ubm.fit(array) + + return self + + def transform(self, X, **kwargs): + """Passthrough. Enroll applies a different transform as score.""" + # The idea would be to apply the projection in Transform (going from extracted + # to GMMStats), but we must not apply this during the training (fit requires + # extracted data directly). + # `project` is applied in the score function directly. + return X + + def _more_tags(self): + return {"bob_fit_supports_dask_array": True} diff --git a/bob/bio/gmm/algorithm/ISV.py b/bob/bio/gmm/algorithm/ISV.py deleted file mode 100644 index 6a5666e..0000000 --- a/bob/bio/gmm/algorithm/ISV.py +++ /dev/null @@ -1,232 +0,0 @@ -#!/usr/bin/env python -# vim: set fileencoding=utf-8 : -# Manuel Guenther <Manuel.Guenther@idiap.ch> - -import logging - -import numpy - -import bob.core -import bob.io.base -import bob.learn.em - -from bob.bio.base.algorithm import Algorithm - -from .GMM import GMM - -logger = logging.getLogger("bob.bio.gmm") - - -class ISV(GMM): - """Tool for computing Unified Background Models and Gaussian Mixture Models of the features""" - - def __init__( - self, - # ISV training - subspace_dimension_of_u, # U subspace dimension - isv_training_iterations=10, # Number of EM iterations for the ISV training - # ISV enrollment - isv_enroll_iterations=1, # Number of iterations for the enrollment phase - multiple_probe_scoring=None, # scoring when multiple probe files are available - # parameters of the GMM - **kwargs - ): - """Initializes the local UBM-GMM tool with the given file selector object""" - # call base class constructor with its set of parameters - 
GMM.__init__(self, **kwargs) - - # call tool constructor to overwrite what was set before - Algorithm.__init__( - self, - performs_projection=True, - use_projected_features_for_enrollment=True, - requires_enroller_training=False, # not needed anymore because it's done while training the projector - split_training_features_by_client=True, - subspace_dimension_of_u=subspace_dimension_of_u, - isv_training_iterations=isv_training_iterations, - isv_enroll_iterations=isv_enroll_iterations, - multiple_model_scoring=None, - multiple_probe_scoring=multiple_probe_scoring, - **kwargs - ) - - self.subspace_dimension_of_u = subspace_dimension_of_u - self.isv_training_iterations = isv_training_iterations - self.isv_enroll_iterations = isv_enroll_iterations - self.isv_trainer = bob.learn.em.ISVTrainer(self.relevance_factor) - - def train_isv(self, data): - """Train the ISV model given a dataset""" - logger.info(" -> Training ISV enroller") - self.isvbase = bob.learn.em.ISVBase(self.ubm, self.subspace_dimension_of_u) - # train ISV model - # Reseting the pseudo random number generator so we can have the same initialization for serial and parallel execution. 
- self.rng = bob.core.random.mt19937(self.init_seed) - bob.learn.em.train( - self.isv_trainer, - self.isvbase, - data, - self.isv_training_iterations, - rng=self.rng, - ) - - def train_projector(self, train_features, projector_file): - """Train Projector and Enroller at the same time""" - [ - self._check_feature(feature) - for client in train_features - for feature in client - ] - - data1 = numpy.vstack(feature for client in train_features for feature in client) - self.train_ubm(data1) - # to save some memory, we might want to delete these data - del data1 - - # project training data - logger.info(" -> Projecting training data") - data = [ - [self.project_ubm(feature) for feature in client] - for client in train_features - ] - - # train ISV - self.train_isv(data) - - # Save the ISV base AND the UBM into the same file - self.save_projector(projector_file) - - def save_projector(self, projector_file): - """Save the GMM and the ISV model in the same HDF5 file""" - hdf5file = bob.io.base.HDF5File(projector_file, "w") - hdf5file.create_group("Projector") - hdf5file.cd("Projector") - self.ubm.save(hdf5file) - - hdf5file.cd("/") - hdf5file.create_group("Enroller") - hdf5file.cd("Enroller") - self.isvbase.save(hdf5file) - - def load_isv(self, isv_file): - hdf5file = bob.io.base.HDF5File(isv_file) - self.isvbase = bob.learn.em.ISVBase(hdf5file) - # add UBM model from base class - self.isvbase.ubm = self.ubm - - def load_projector(self, projector_file): - """Load the GMM and the ISV model from the same HDF5 file""" - hdf5file = bob.io.base.HDF5File(projector_file) - - # Load Projector - hdf5file.cd("/Projector") - self.load_ubm(hdf5file) - - # Load Enroller - hdf5file.cd("/Enroller") - self.load_isv(hdf5file) - - ####################################################### - # ISV training # - def project_isv(self, projected_ubm): - projected_isv = numpy.ndarray( - shape=(self.ubm.shape[0] * self.ubm.shape[1],), dtype=numpy.float64 - ) - model = 
bob.learn.em.ISVMachine(self.isvbase) - model.estimate_ux(projected_ubm, projected_isv) - return projected_isv - - def project(self, feature): - """Computes GMM statistics against a UBM, then corresponding Ux vector""" - self._check_feature(feature) - projected_ubm = GMM.project(self, feature) - projected_isv = self.project_isv(projected_ubm) - return [projected_ubm, projected_isv] - - ####################################################### - # ISV model enroll # - - def write_feature(self, data, feature_file): - gmmstats = data[0] - Ux = data[1] - hdf5file = ( - bob.io.base.HDF5File(feature_file, "w") - if isinstance(feature_file, str) - else feature_file - ) - hdf5file.create_group("gmmstats") - hdf5file.cd("gmmstats") - gmmstats.save(hdf5file) - hdf5file.cd("..") - hdf5file.set("Ux", Ux) - - def read_feature(self, feature_file): - """Read the type of features that we require, namely GMMStats""" - hdf5file = bob.io.base.HDF5File(feature_file) - hdf5file.cd("gmmstats") - gmmstats = bob.learn.em.GMMStats(hdf5file) - hdf5file.cd("..") - Ux = hdf5file.read("Ux") - return [gmmstats, Ux] - - def _check_projected(self, probe): - """Checks that the probe is of the desired type""" - assert isinstance(probe, (tuple, list)) - assert len(probe) == 2 - assert isinstance(probe[0], bob.learn.em.GMMStats) - assert ( - isinstance(probe[1], numpy.ndarray) - and probe[1].ndim == 1 - and probe[1].dtype == numpy.float64 - ) - - def enroll(self, enroll_features): - """Performs ISV enrollment""" - for feature in enroll_features: - self._check_projected(feature) - machine = bob.learn.em.ISVMachine(self.isvbase) - self.isv_trainer.enroll( - machine, [f[0] for f in enroll_features], self.isv_enroll_iterations - ) - # return the resulting gmm - return machine - - ###################################################### - # Feature comparison # - def read_model(self, model_file): - """Reads the ISV Machine that holds the model""" - machine = 
bob.learn.em.ISVMachine(bob.io.base.HDF5File(model_file)) - machine.isv_base = self.isvbase - return machine - - def score(self, model, probe): - """Computes the score for the given model and the given probe.""" - assert isinstance(model, bob.learn.em.ISVMachine) - self._check_projected(probe) - - gmmstats = probe[0] - Ux = probe[1] - return model.forward_ux(gmmstats, Ux) - - def score_for_multiple_probes(self, model, probes): - """This function computes the score between the given model and several given probe files.""" - assert isinstance(model, bob.learn.em.ISVMachine) - [self._check_projected(probe) for probe in probes] - if self.probe_fusion_function is not None: - # When a multiple probe fusion function is selected, use it - return Algorithm.score_for_multiple_probes(self, model, probes) - else: - # Otherwise: compute joint likelihood of all probe features - # create GMM statistics from first probe statistics - # import pdb; pdb.set_trace() - gmmstats_acc = bob.learn.em.GMMStats(probes[0][0]) - # gmmstats_acc = probes[0][0] - # add all other probe statistics - for i in range(1, len(probes)): - gmmstats_acc += probes[i][0] - # compute ISV score with the accumulated statistics - projected_isv_acc = numpy.ndarray( - shape=(self.ubm.shape[0] * self.ubm.shape[1],), dtype=numpy.float64 - ) - model.estimate_ux(gmmstats_acc, projected_isv_acc) - return model.forward_ux(gmmstats_acc, projected_isv_acc) diff --git a/bob/bio/gmm/algorithm/IVector.py b/bob/bio/gmm/algorithm/IVector.py deleted file mode 100644 index 94c6f4f..0000000 --- a/bob/bio/gmm/algorithm/IVector.py +++ /dev/null @@ -1,412 +0,0 @@ -#!/usr/bin/env python -# vim: set fileencoding=utf-8 : -# Laurent El Shafey <Laurent.El-Shafey@idiap.ch> - -import logging - -import numpy - -import bob.core -import bob.io.base -import bob.learn.em -import bob.learn.linear - -from bob.bio.base.algorithm import Algorithm - -from .GMM import GMM - -logger = logging.getLogger("bob.bio.gmm") - - -class IVector(GMM): - """Tool 
for extracting I-Vectors""" - - def __init__( - self, - # IVector training - subspace_dimension_of_t, # T subspace dimension - tv_training_iterations=25, # Number of EM iterations for the JFA training - update_sigma=True, - use_whitening=True, - use_lda=False, - use_wccn=False, - use_plda=False, - lda_dim=None, - lda_strip_to_rank=True, - plda_dim_F=50, - plda_dim_G=50, - plda_training_iterations=50, - # parameters of the GMM - **kwargs - ): - """Initializes the local GMM tool with the given file selector object""" - # call base class constructor with its set of parameters - GMM.__init__(self, **kwargs) - - # call tool constructor to overwrite what was set before - Algorithm.__init__( - self, - performs_projection=True, - use_projected_features_for_enrollment=True, - requires_enroller_training=False, # not needed anymore because it's done while training the projector - split_training_features_by_client=True, - subspace_dimension_of_t=subspace_dimension_of_t, - tv_training_iterations=tv_training_iterations, - update_sigma=update_sigma, - use_whitening=use_whitening, - use_lda=use_lda, - use_wccn=use_wccn, - use_plda=use_plda, - lda_dim=lda_dim, - lda_strip_to_rank=lda_strip_to_rank, - plda_dim_F=plda_dim_F, - plda_dim_G=plda_dim_G, - plda_training_iterations=plda_training_iterations, - multiple_model_scoring=None, - multiple_probe_scoring=None, - **kwargs - ) - - self.update_sigma = update_sigma - self.use_whitening = use_whitening - self.use_lda = use_lda - self.use_wccn = use_wccn - self.use_plda = use_plda - self.subspace_dimension_of_t = subspace_dimension_of_t - self.tv_training_iterations = tv_training_iterations - - self.ivector_trainer = bob.learn.em.IVectorTrainer(update_sigma=update_sigma) - self.whitening_trainer = bob.learn.linear.WhiteningTrainer() - - self.lda_dim = lda_dim - self.lda_trainer = bob.learn.linear.FisherLDATrainer( - strip_to_rank=lda_strip_to_rank - ) - self.wccn_trainer = bob.learn.linear.WCCNTrainer() - self.plda_trainer = 
bob.learn.em.PLDATrainer() - self.plda_dim_F = plda_dim_F - self.plda_dim_G = plda_dim_G - self.plda_training_iterations = plda_training_iterations - - def _check_ivector(self, feature): - """Checks that the features are appropriate""" - if ( - not isinstance(feature, numpy.ndarray) - or feature.ndim != 1 - or feature.dtype != numpy.float64 - ): - raise ValueError("The given feature is not appropriate") - - def train_ivector(self, training_stats): - logger.info(" -> Training IVector enroller") - self.tv = bob.learn.em.IVectorMachine( - self.ubm, self.subspace_dimension_of_t, self.variance_threshold - ) - - # Reseting the pseudo random number generator so we can have the same initialization for serial and parallel execution. - self.rng = bob.core.random.mt19937(self.init_seed) - - # train IVector model - bob.learn.em.train( - self.ivector_trainer, - self.tv, - training_stats, - self.tv_training_iterations, - rng=self.rng, - ) - - def train_whitener(self, training_features): - logger.info(" -> Training Whitening") - ivectors_matrix = numpy.vstack(training_features) - # create a Linear Machine - self.whitener = bob.learn.linear.Machine( - ivectors_matrix.shape[1], ivectors_matrix.shape[1] - ) - # create the whitening trainer - self.whitening_trainer.train(ivectors_matrix, self.whitener) - - def train_lda(self, training_features): - logger.info(" -> Training LDA projector") - self.lda, __eig_vals = self.lda_trainer.train(training_features) - - # resize the machine if desired - # You can only clip if the rank is higher than LDA_DIM - if self.lda_dim is not None: - if len(__eig_vals) < self.lda_dim: - logger.warning( - " -> You are resizing the LDA matrix to a value above its rank" - "(from {0} to {1}). 
Be aware that this may lead you to imprecise eigenvectors.".format( - len(__eig_vals), self.lda_dim - ) - ) - self.lda.resize(self.lda.shape[0], self.lda_dim) - - def train_wccn(self, training_features): - logger.info(" -> Training WCCN projector") - self.wccn = self.wccn_trainer.train(training_features) - - def train_plda(self, training_features): - logger.info(" -> Training PLDA projector") - self.plda_trainer.init_f_method = "BETWEEN_SCATTER" - self.plda_trainer.init_g_method = "WITHIN_SCATTER" - self.plda_trainer.init_sigma_method = "VARIANCE_DATA" - variance_flooring = 1e-5 - training_features = [numpy.vstack(client) for client in training_features] - input_dim = training_features[0].shape[1] - - # Reseting the pseudo random number generator so we can have the same initialization for serial and parallel execution. - self.rng = bob.core.random.mt19937(self.init_seed) - - self.plda_base = bob.learn.em.PLDABase( - input_dim, self.plda_dim_F, self.plda_dim_G, variance_flooring - ) - bob.learn.em.train( - self.plda_trainer, - self.plda_base, - training_features, - self.plda_training_iterations, - rng=self.rng, - ) - - def train_projector(self, train_features, projector_file): - """Train Projector and Enroller at the same time""" - - [ - self._check_feature(feature) - for client in train_features - for feature in client - ] - - # train UBM - data = numpy.vstack(feature for client in train_features for feature in client) - self.train_ubm(data) - del data - - # project training data - logger.info(" -> Projecting training data") - train_gmm_stats = [ - [self.project_ubm(feature) for feature in client] - for client in train_features - ] - train_gmm_stats_flatten = [ - stats for client in train_gmm_stats for stats in client - ] - - # train IVector - logger.info(" -> Projecting training data") - self.train_ivector(train_gmm_stats_flatten) - - # project training i-vectors - train_ivectors = [ - [self.project_ivector(stats) for stats in client] - for client in 
train_gmm_stats - ] - train_ivectors_flatten = [ - stats for client in train_ivectors for stats in client - ] - - if self.use_whitening: - # Train Whitening - self.train_whitener(train_ivectors_flatten) - # whitening and length-normalizing i-vectors - train_ivectors = [ - [self.project_whitening(ivec) for ivec in client] - for client in train_ivectors - ] - - if self.use_lda: - self.train_lda(train_ivectors) - train_ivectors = [ - [self.project_lda(ivec) for ivec in client] for client in train_ivectors - ] - - if self.use_wccn: - self.train_wccn(train_ivectors) - train_ivectors = [ - [self.project_wccn(ivec) for ivec in client] - for client in train_ivectors - ] - - if self.use_plda: - self.train_plda(train_ivectors) - - # save - self.save_projector(projector_file) - - def save_projector(self, projector_file): - # Save the IVector base AND the UBM AND the whitening into the same file - hdf5file = bob.io.base.HDF5File(projector_file, "w") - hdf5file.create_group("Projector") - hdf5file.cd("Projector") - self.save_ubm(hdf5file) - - hdf5file.cd("/") - hdf5file.create_group("Enroller") - hdf5file.cd("Enroller") - self.tv.save(hdf5file) - - if self.use_whitening: - hdf5file.cd("/") - hdf5file.create_group("Whitener") - hdf5file.cd("Whitener") - self.whitener.save(hdf5file) - - if self.use_lda: - hdf5file.cd("/") - hdf5file.create_group("LDA") - hdf5file.cd("LDA") - self.lda.save(hdf5file) - - if self.use_wccn: - hdf5file.cd("/") - hdf5file.create_group("WCCN") - hdf5file.cd("WCCN") - self.wccn.save(hdf5file) - - if self.use_plda: - hdf5file.cd("/") - hdf5file.create_group("PLDA") - hdf5file.cd("PLDA") - self.plda_base.save(hdf5file) - - def load_tv(self, tv_file): - hdf5file = bob.io.base.HDF5File(tv_file) - self.tv = bob.learn.em.IVectorMachine(hdf5file) - # add UBM model from base class - self.tv.ubm = self.ubm - - def load_whitener(self, whitening_file): - hdf5file = bob.io.base.HDF5File(whitening_file) - self.whitener = bob.learn.linear.Machine(hdf5file) - - def 
load_lda(self, lda_file): - hdf5file = bob.io.base.HDF5File(lda_file) - self.lda = bob.learn.linear.Machine(hdf5file) - - def load_wccn(self, wccn_file): - hdf5file = bob.io.base.HDF5File(wccn_file) - self.wccn = bob.learn.linear.Machine(hdf5file) - - def load_plda(self, plda_file): - hdf5file = bob.io.base.HDF5File(plda_file) - self.plda_base = bob.learn.em.PLDABase(hdf5file) - self.plda_machine = bob.learn.em.PLDAMachine(self.plda_base) - - def load_projector(self, projector_file): - """Load the GMM and the ISV model from the same HDF5 file""" - hdf5file = bob.io.base.HDF5File(projector_file) - - # Load Projector - hdf5file.cd("/Projector") - self.load_ubm(hdf5file) - - # Load Enroller - hdf5file.cd("/Enroller") - self.load_tv(hdf5file) - - if self.use_whitening: - # Load Whitening - hdf5file.cd("/Whitener") - self.load_whitener(hdf5file) - - if self.use_lda: - # Load LDA - hdf5file.cd("/LDA") - self.load_lda(hdf5file) - - if self.use_wccn: - # Load WCCN - hdf5file.cd("/WCCN") - self.load_wccn(hdf5file) - - if self.use_plda: - # Load PLDA - hdf5file.cd("/PLDA") - self.load_plda(hdf5file) - - def project_ivector(self, gmm_stats): - return self.tv.project(gmm_stats) - - def project_whitening(self, ivector): - whitened = self.whitener.forward(ivector) - return whitened / numpy.linalg.norm(whitened) - - def project_lda(self, ivector): - out_ivector = numpy.ndarray(self.lda.shape[1], numpy.float64) - self.lda(ivector, out_ivector) - return out_ivector - - def project_wccn(self, ivector): - out_ivector = numpy.ndarray(self.wccn.shape[1], numpy.float64) - self.wccn(ivector, out_ivector) - return out_ivector - - ####################################################### - # IVector projection # - def project(self, feature_array): - """Computes GMM statistics against a UBM, then corresponding Ux vector""" - self._check_feature(feature_array) - # project UBM - projected_ubm = self.project_ubm(feature_array) - # project I-Vector - ivector = self.project_ivector(projected_ubm) 
- # whiten I-Vector - if self.use_whitening: - ivector = self.project_whitening(ivector) - # LDA projection - if self.use_lda: - ivector = self.project_lda(ivector) - # WCCN projection - if self.use_wccn: - ivector = self.project_wccn(ivector) - return ivector - - ####################################################### - # Read / Write I-Vectors # - def write_feature(self, data, feature_file): - """Saves the feature, which is the (whitened) I-Vector.""" - bob.bio.base.save(data, feature_file) - - def read_feature(self, feature_file): - """Read the type of features that we require, namely i-vectors (stored as simple numpy arrays)""" - return bob.bio.base.load(feature_file) - - ####################################################### - # Model Enrollment # - def enroll(self, enroll_features): - """Performs IVector enrollment""" - [self._check_ivector(feature) for feature in enroll_features] - average_ivector = numpy.mean(numpy.vstack(enroll_features), axis=0) - if self.use_plda: - average_ivector = average_ivector.reshape(1, -1) - self.plda_trainer.enroll(self.plda_machine, average_ivector) - return self.plda_machine - else: - return average_ivector - - ###################################################### - # Feature comparison # - def read_model(self, model_file): - """Reads the whitened i-vector that holds the model""" - if self.use_plda: - return bob.learn.em.PLDAMachine( - bob.io.base.HDF5File(str(model_file)), self.plda_base - ) - else: - return bob.bio.base.load(model_file) - - def score(self, model, probe): - """Computes the score for the given model and the given probe.""" - self._check_ivector(probe) - if self.use_plda: - return model.log_likelihood_ratio(probe) - else: - self._check_ivector(model) - return numpy.dot( - model / numpy.linalg.norm(model), probe / numpy.linalg.norm(probe) - ) - - def score_for_multiple_probes(self, model, probes): - """This function computes the score between the given model and several given probe files.""" - probe = 
numpy.mean(numpy.vstack(probes), axis=0) - return self.score(model, probe) diff --git a/bob/bio/gmm/algorithm/JFA.py b/bob/bio/gmm/algorithm/JFA.py deleted file mode 100644 index 4280d26..0000000 --- a/bob/bio/gmm/algorithm/JFA.py +++ /dev/null @@ -1,131 +0,0 @@ -#!/usr/bin/env python -# vim: set fileencoding=utf-8 : -# Manuel Guenther <Manuel.Guenther@idiap.ch> - -import logging - -import bob.core -import bob.io.base -import bob.learn.em - -from bob.bio.base.algorithm import Algorithm - -from .GMM import GMM - -logger = logging.getLogger("bob.bio.gmm") - - -class JFA(GMM): - """Tool for computing Unified Background Models and Gaussian Mixture Models of the features and project it via JFA""" - - def __init__( - self, - # JFA training - subspace_dimension_of_u, # U subspace dimension - subspace_dimension_of_v, # V subspace dimension - jfa_training_iterations=10, # Number of EM iterations for the JFA training - # JFA enrollment - jfa_enroll_iterations=1, # Number of iterations for the enrollment phase - # parameters of the GMM - **kwargs - ): - """Initializes the local UBM-GMM tool with the given file selector object""" - # call base class constructor - GMM.__init__(self, **kwargs) - - # call tool constructor to overwrite what was set before - Algorithm.__init__( - self, - performs_projection=True, - use_projected_features_for_enrollment=True, - requires_enroller_training=True, - subspace_dimension_of_u=subspace_dimension_of_u, - subspace_dimension_of_v=subspace_dimension_of_v, - jfa_training_iterations=jfa_training_iterations, - jfa_enroll_iterations=jfa_enroll_iterations, - multiple_model_scoring=None, - multiple_probe_scoring=None, - **kwargs - ) - - self.subspace_dimension_of_u = subspace_dimension_of_u - self.subspace_dimension_of_v = subspace_dimension_of_v - self.jfa_training_iterations = jfa_training_iterations - self.jfa_enroll_iterations = jfa_enroll_iterations - self.jfa_trainer = bob.learn.em.JFATrainer() - - def load_projector(self, projector_file): - 
"""Reads the UBM model from file""" - # Here, we just need to load the UBM from the projector file. - self.load_ubm(projector_file) - - ####################################################### - # JFA training # - def train_enroller(self, train_features, enroller_file): - # assert that all training features are GMMStatistics - for client_feature in train_features: - for feature in client_feature: - assert isinstance(feature, bob.learn.em.GMMStats) - - # create a JFABasemachine with the UBM from the base class - self.jfa_base = bob.learn.em.JFABase( - self.ubm, self.subspace_dimension_of_u, self.subspace_dimension_of_v - ) - - # train the JFA - bob.learn.em.train_jfa( - self.jfa_trainer, - self.jfa_base, - train_features, - self.jfa_training_iterations, - rng=bob.core.random.mt19937(self.init_seed), - ) - - # Save the JFA base AND the UBM into the same file - self.jfa_base.save(bob.io.base.HDF5File(enroller_file, "w")) - - ####################################################### - # JFA model enroll # - def load_enroller(self, enroller_file): - """Reads the JFA base from file""" - # now, load the JFA base, if it is included in the file - self.jfa_base = bob.learn.em.JFABase(bob.io.base.HDF5File(enroller_file)) - # add UBM model from base class - self.jfa_base.ubm = self.ubm - - # TODO: Why is the rng re-initialized here? 
- # self.rng = bob.core.random.mt19937(self.init_seed) - - def read_feature(self, feature_file): - """Reads the projected feature to be enrolled as a model""" - return bob.learn.em.GMMStats(bob.io.base.HDF5File(feature_file)) - - def enroll(self, enroll_features): - """Enrolls a GMM using MAP adaptation""" - machine = bob.learn.em.JFAMachine(self.jfa_base) - self.jfa_trainer.enroll(machine, enroll_features, self.jfa_enroll_iterations) - # return the resulting gmm - return machine - - ###################################################### - # Feature comparison # - def read_model(self, model_file): - """Reads the JFA Machine that holds the model""" - machine = bob.learn.em.JFAMachine(bob.io.base.HDF5File(model_file)) - machine.jfa_base = self.jfa_base - return machine - - def score(self, model, probe): - """Computes the score for the given model and the given probe""" - assert isinstance(model, bob.learn.em.JFAMachine) - assert isinstance(probe, bob.learn.em.GMMStats) - return model.log_likelihood(probe) - - def score_for_multiple_probes(self, model, probes): - """This function computes the score between the given model and several probes.""" - # TODO: Check if this is correct - # logger.warn("This function needs to be verified!") - raise NotImplementedError("Multiple probes is not yet supported") - # scores = numpy.ndarray((len(probes),), 'float64') - # model.forward(probes, scores) - # return scores[0] diff --git a/bob/bio/gmm/algorithm/__init__.py b/bob/bio/gmm/algorithm/__init__.py index fc5f4fe..cf76a6b 100644 --- a/bob/bio/gmm/algorithm/__init__.py +++ b/bob/bio/gmm/algorithm/__init__.py @@ -1,8 +1,4 @@ -# from .GMM import GMM -# from .GMM import GMMRegular -from .ISV import ISV -from .IVector import IVector -from .JFA import JFA +from .GMM import GMM # gets sphinx autodoc done right - don't remove it @@ -22,10 +18,6 @@ def __appropriate__(*args): __appropriate__( - # GMM, - # GMMRegular, - JFA, - ISV, - IVector, + GMM, ) __all__ = [_ for _ in dir() if not 
_.startswith("_")] diff --git a/bob/bio/gmm/bioalgorithm/GMM.py b/bob/bio/gmm/bioalgorithm/GMM.py deleted file mode 100644 index ed574e2..0000000 --- a/bob/bio/gmm/bioalgorithm/GMM.py +++ /dev/null @@ -1,322 +0,0 @@ -#!/usr/bin/env python -# vim: set fileencoding=utf-8 : -# Manuel Guenther <Manuel.Guenther@idiap.ch> - -"""Interface between the lower level GMM classes and the Algorithm Transformer. - -Implements the enroll and score methods using the low level GMM implementation. - -This adds the notions of models, probes, enrollment, and scores to GMM. -""" - - -import logging - -from typing import Callable - -import dask.array as da -import numpy as np -import dask -from h5py import File as HDF5File - -from sklearn.base import BaseEstimator - -from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgorithm -from bob.learn.em.mixture import GMMMachine -from bob.learn.em.mixture import GMMStats -from bob.learn.em.mixture import linear_scoring -from bob.pipelines.wrappers import DaskWrapper - -logger = logging.getLogger(__name__) - -# from bob.pipelines import ToDaskBag # Used when switching from samples to da.Array - - -class GMM(BioAlgorithm, BaseEstimator): - """Algorithm for computing UBM and Gaussian Mixture Models of the features. - - Features must be normalized to zero mean and unit standard deviation. - - Models are MAP GMM machines trained from a UBM on the enrollment feature set. - - The UBM is a ML GMM machine trained on the training feature set. - - Probes are GMM statistics of features projected on the UBM. 
- """ - - def __init__( - self, - # parameters for the GMM - number_of_gaussians: int, - # parameters of UBM training - kmeans_training_iterations: int = 25, # Maximum number of iterations for K-Means - ubm_training_iterations: int = 25, # Maximum number of iterations for GMM Training - training_threshold: float = 5e-4, # Threshold to end the ML training - variance_threshold: float = 5e-4, # Minimum value that a variance can reach - update_weights: bool = True, - update_means: bool = True, - update_variances: bool = True, - # parameters of the GMM enrollment - relevance_factor: float = 4, # Relevance factor as described in Reynolds paper - gmm_enroll_iterations: int = 1, # Number of iterations for the enrollment phase - responsibility_threshold: float = 0, # If set, the weight of a particular Gaussian will at least be greater than this threshold. In the case the real weight is lower, the prior mean value will be used to estimate the current mean and variance. - init_seed: int = 5489, - # scoring - scoring_function: Callable = linear_scoring, - # n_threads=None, - ): - """Initializes the local UBM-GMM tool chain. - - Parameters - ---------- - number_of_gaussians - The number of Gaussians used in the UBM and the models. - kmeans_training_iterations - Number of e-m iterations to train k-means initializing the UBM. - ubm_training_iterations - Number of e-m iterations for training the UBM. - training_threshold - Convergence threshold to halt the GMM training early. - variance_threshold - Minimum value a variance of the Gaussians can reach. - update_weights - Decides wether the weights of the Gaussians are updated while training. - update_means - Decides wether the means of the Gaussians are updated while training. - update_variances - Decides wether the variancess of the Gaussians are updated while training. - relevance_factor - Relevance factor as described in Reynolds paper. - gmm_enroll_iterations - Number of iterations for the MAP GMM used for enrollment. 
- responsibility_threshold - If set, the weight of a particular Gaussian will at least be greater than - this threshold. In the case where the real weight is lower, the prior mean - value will be used to estimate the current mean and variance. - init_seed - Seed for the random number generation. - scoring_function - Function returning a score from a model, a UBM, and a probe. - """ - - # call base class constructor and register that this tool performs projection - # super().__init__(score_reduction_operation=??) - - # copy parameters - self.number_of_gaussians = number_of_gaussians - self.kmeans_training_iterations = kmeans_training_iterations - self.ubm_training_iterations = ubm_training_iterations - self.training_threshold = training_threshold - self.variance_threshold = variance_threshold - self.update_weights = update_weights - self.update_means = update_means - self.update_variances = update_variances - self.relevance_factor = relevance_factor - self.gmm_enroll_iterations = gmm_enroll_iterations - self.init_seed = init_seed - self.rng = self.init_seed # TODO verify if rng object needed - self.responsibility_threshold = responsibility_threshold - - def scoring_function_wrapped(*args, **kwargs): - with dask.config.set(scheduler="threads"): - return scoring_function(*args, **kwargs) - - self.scoring_function = scoring_function_wrapped - - self.ubm = None - - super().__init__() - - def _check_feature(self, feature): - """Checks that the features are appropriate""" - if ( - not isinstance(feature, np.ndarray) - or feature.ndim != 2 - or feature.dtype != np.float64 - ): - raise ValueError(f"The given feature is not appropriate: \n{feature}") - if self.ubm is not None and feature.shape[1] != self.ubm.shape[1]: - raise ValueError( - "The given feature is expected to have %d elements, but it has %d" - % (self.ubm.shape[1], feature.shape[1]) - ) - - def save_ubm(self, ubm_file): - """Saves the projector to file""" - # Saves the UBM to file - logger.debug("Saving model 
to file '%s'", ubm_file) - - hdf5 = ( - ubm_file - if isinstance(ubm_file, HDF5File) - else HDF5File(ubm_file, "w") - ) - self.ubm.save(hdf5) - - def load_ubm(self, ubm_file): - hdf5file = HDF5File(ubm_file) - logger.debug("Loading model from file '%s'", ubm_file) - # read UBM - self.ubm = GMMMachine.from_hdf5(hdf5file) - self.ubm.variance_thresholds = self.variance_threshold - - def project(self, array): - """Computes GMM statistics against a UBM, given a 2D array of feature vectors""" - self._check_feature(array) - logger.debug(" .... Projecting %d feature vectors", array.shape[0]) - # Accumulates statistics - with dask.config.set(scheduler="threads"): - gmm_stats = GMMStats(self.ubm.shape[0], self.ubm.shape[1]) - self.ubm.acc_statistics(array, gmm_stats) - gmm_stats.compute() - - # return the resulting statistics - return gmm_stats - - def read_feature(self, feature_file): - """Read the type of features that we require, namely GMM_Stats""" - return GMMStats.from_hdf5(HDF5File(feature_file)) - - def write_feature(self, feature, feature_file): - """Write the features (GMM_Stats)""" - return feature.save(feature_file) - - def enroll(self, data): - """Enrolls a GMM using MAP adaptation, given a list of 2D np.ndarray's of feature vectors""" - [self._check_feature(feature) for feature in data] - array = np.vstack(data) - # Use the array to train a GMM and return it - logger.debug(" .... Enrolling with %d feature vectors", array.shape[0]) - - # TODO responsibility_threshold - with dask.config.set(scheduler="threads"): - gmm = GMMMachine( - n_gaussians=self.number_of_gaussians, - trainer="map", - ubm=self.ubm, - convergence_threshold=self.training_threshold, - max_fitting_steps=self.gmm_enroll_iterations, - random_state=self.rng, - update_means=True, - update_variances=True, # TODO default? - update_weights=True, # TODO default? 
- ) - gmm.variance_thresholds = self.variance_threshold - gmm = gmm.fit(array) - # info = {k: type(v) for k, v in gmm.__dict__.items()} - # for k, v in gmm.gaussians_.__dict__.items(): - # info[k] = type(v) - # raise ValueError(str(info)) - return gmm - - def read_model(self, model_file): - """Reads the model, which is a GMM machine""" - return GMMMachine.from_hdf5(HDF5File(model_file), ubm=self.ubm) - - def write_model(self, model, model_file): - """Write the features (GMM_Stats)""" - return model.save(model_file) - - def score(self, biometric_reference: GMMMachine, data: GMMStats): - """Computes the score for the given model and the given probe. - - Uses the scoring function passed during initialization. - - Parameters - ---------- - biometric_reference: - The model to score against. - data: - The probe data to compare to the model. - """ - - assert isinstance(biometric_reference, GMMMachine) - return self.scoring_function( - models_means=[biometric_reference], - ubm=self.ubm, - test_stats=data, - frame_length_normalization=True, - )[0, 0] - - def score_multiple_biometric_references( - self, biometric_references: "list[GMMMachine]", data: GMMStats - ): - """Computes the score between multiple models and one probe. - - Uses the scoring function passed during initialization. - - Parameters - ---------- - biometric_references: - The models to score against. - data: - The probe data to compare to the models. 
- """ - - assert isinstance(biometric_references[0], GMMMachine), type( - biometric_references[0] - ) - stats = self.project(data) - return self.scoring_function( - models_means=biometric_references, - ubm=self.ubm, - test_stats=stats, - frame_length_normalization=True, - ) - - def score_for_multiple_probes(self, model, probes): - """This function computes the score between the given model and several given probe files.""" - assert isinstance(model, GMMMachine) - for probe in probes: - assert isinstance(probe, GMMStats) - # logger.warn("Please verify that this function is correct") - return ( - self.scoring_function( - models_means=model.means, - ubm=self.ubm, - test_stats=probes, - frame_length_normalization=True, - ) - .mean() - .reshape((-1,)) - ) - - def fit(self, X, y=None, **kwargs): - """Trains the UBM.""" - - # Stack all the samples in a 2D array of features - array = da.vstack(X) - - logger.debug("UBM with %d feature vectors", array.shape[0]) - - logger.debug(f"Creating UBM machine with {self.number_of_gaussians} gaussians") - - self.ubm = GMMMachine( - n_gaussians=self.number_of_gaussians, - trainer="ml", - max_fitting_steps=self.ubm_training_iterations, - convergence_threshold=self.training_threshold, - update_means=self.update_means, - update_variances=self.update_variances, - update_weights=self.update_weights, - # TODO more params? - ) - - # Trains the GMM - logger.info("Training UBM GMM") - # Resetting the pseudo random number generator so we can have the same initialization for serial and parallel execution. - # self.rng = bob.core.random.mt19937(self.init_seed) - self.ubm = self.ubm.fit(array) - - return self - - def transform(self, X, **kwargs): - """Passthrough. Enroll applies a different transform as score.""" - # The idea would be to apply the projection in Transform (going from extracted - # to GMMStats), but we must not apply this during the training (fit requires - # extracted data directly). 
- # `project` is applied in the score function directly. - return X - - def _more_tags(self): - return {"bob_fit_supports_dask_array": True} diff --git a/bob/bio/gmm/bioalgorithm/__init__.py b/bob/bio/gmm/bioalgorithm/__init__.py deleted file mode 100644 index cf76a6b..0000000 --- a/bob/bio/gmm/bioalgorithm/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from .GMM import GMM - - -# gets sphinx autodoc done right - don't remove it -def __appropriate__(*args): - """Says object was actually declared here, and not in the import module. - Fixing sphinx warnings of not being able to find classes, when path is shortened. - Parameters: - - *args: An iterable of objects to modify - - Resolves `Sphinx referencing issues - <https://github.com/sphinx-doc/sphinx/issues/3048>` - """ - - for obj in args: - obj.__module__ = __name__ - - -__appropriate__( - GMM, -) -__all__ = [_ for _ in dir() if not _.startswith("_")] diff --git a/bob/bio/gmm/config/algorithm/gmm.py b/bob/bio/gmm/config/algorithm/gmm.py index 9b280de..ce235b6 100644 --- a/bob/bio/gmm/config/algorithm/gmm.py +++ b/bob/bio/gmm/config/algorithm/gmm.py @@ -1,5 +1,3 @@ import bob.bio.gmm -algorithm = bob.bio.gmm.algorithm.GMM( - number_of_gaussians=512, -) +algorithm = bob.bio.gmm.algorithm.GMM(number_of_gaussians=512) diff --git a/bob/bio/gmm/config/bioalgorithm/__init__.py b/bob/bio/gmm/config/bioalgorithm/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/bob/bio/gmm/config/bioalgorithm/gmm.py b/bob/bio/gmm/config/bioalgorithm/gmm.py deleted file mode 100644 index 58aeddd..0000000 --- a/bob/bio/gmm/config/bioalgorithm/gmm.py +++ /dev/null @@ -1,3 +0,0 @@ -import bob.bio.gmm - -bioalgorithm = bob.bio.gmm.bioalgorithm.GMM(number_of_gaussians=512) diff --git a/bob/bio/gmm/config/bioalgorithm/gmm_regular.py b/bob/bio/gmm/config/bioalgorithm/gmm_regular.py deleted file mode 100644 index f7166b5..0000000 --- a/bob/bio/gmm/config/bioalgorithm/gmm_regular.py +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env 
python - -import bob.bio.gmm - -bioalgorithm = bob.bio.gmm.bioalgorithm.GMMRegular(number_of_gaussians=512) diff --git a/bob/bio/gmm/test/data/gmm_model.hdf5 b/bob/bio/gmm/test/data/gmm_model.hdf5 index 21bf67f4cc54782e78f4b9750e15fefc76089089..07c92210b263582fb47143438de7fe31564bf091 100644 GIT binary patch delta 2493 zcmZqh`;jt1gQ>!3qE;W9h6V!z1H)ti_74*;{F*Ev$T4{XlK@kL*2E||t_Rv6IR*v` z$%zNeCMz%td2)aS7{Me112cmJ1Bflj$jBhTzyQV|J_7?p9V-I|Se}UqOfhgUNPu-R zGcru}7Exp__gyu4KJ!AR7m^bX%1usSk&r^z2v*3D0ihUvNP$#rV7cW5F&SbDB`y$< znCv5>0CvKD`NfPJlUK9ZGbu<+e#54~Uhcapk%3{t#Djl5m_gox*a`N7yR*L!*m976 z85qD~AUA^4foMZL0}#c)1ST06bfNU*L=GhbUtizE<c!R`RCPNBklP`0Fs-?X1rPx! z*}%W~qttiCP0}xL;)xg31wl3lKwSv3m1%Mzt0dKeNkai1O!WdYB_p72vw+eOl2FFv zg-Z5}3KI|dgOUj&$K=O??o1X6pwL7MVh6AwB8aC>ejuF1<bklpUuZHT$K+}WO&&K# zpI}!828NKJP}j*DB$SvU6ej10LgEaMsz_v2jT0|?s0W2GEQmntKPx|8@=bEE2S+3W z1B2X^8#nhlKe0a&?&s4|aowJgk#U8hhgRg|RD0f<M-vS`KCqWsy*PR~%T@b1=Qh9f zcb>8T!8g9L)-MnDr<}RAsE6~(ehu}r-i%e(_D}Hlxb-CFf&Hdjo-K@F_w4sFzd2q1 z<M|T%m0v2qZ@jQ+zv2AlbAt-c?hn}7Tl@aX!~Lc8oJN{oZ`m*5akA#Kzh?jOjMaym zyYATAZO?l9_xU+{*=?&8d=+=>fA;>5jWO?){cbV!FP20-vG=W+Egx(D*<LofUu#v; z<NZOV;m?9soZElK@s9HH*<0*`UJG*Zt=wk+Y+i8vjFXnf?CV>d=9Sr9-#<b9OLd~| zZTr8cV|TIG9kLIdx%OB_>(%{HkL;UQO@6mu+TeLYlk_orwfpIUXD6K7Z#qe^uUF*J z{+Wdt+sm2`?4KHW^@HDwyY~Mpo}5ZHIB7pQEARA7gZK6>4*tbzjmP)L$OZV_D15hH zSH1g`yKv|J=hNx~TqPSG@3)#6f2w%bh5bxH%o8%Qj@Y+7NniG1=l1;zErn91n!m7D zyRc-sgy<UkjRE_r1$!^;pSMB5e8%NB``OmH#}}7g+JD&Ws7v1Bd-h>H#a$U29@rPm z(>`3@wbI^y>cxQH7cTE#a#~sVd+r%~@yk~0wR6wykDA1xcK!Tq`)&1q?_5@wda?hA zZt+6*ulwwa_jOjCUHfi-(PA5O`>9X%@AEn;b#LB{{mnPtDBJYB-S7Ub$j)@vvHhzL zA64m}dUpS*{y9RQHa)UecI`jmD|W~JQSU4lpRg15TTQR;5)9k6zxQ1NTiJ~R`{k{( z`BG#L+V8KJWU?#$qJ7@qD>vEBp4?x*SwH(L>yE4T*)!xqqU(;?a|Bdo=LdbU&;2-K zS&`!A{d=qhG9RWKwii)%((5p6u)q6VJu>0sLHjK-r!-D=9pBIPw9jtchg0_dy!E&L zT>HSjW$%TE?+^R!y@hjS1nl4M_mK0qVL116zm4VW!VCA`?hj)>>U-Muk^QM9Q|`J> zd01~B>naj+MDnVA_{aLz@ML9&#kQP(me(pd*lp!lzi*9#gXoFr+o#@_aj=~xKdr|} z$-yS7FgmzU-XTuBHamI0ti$Pb_E)11C_407t>d!bRdVP|_TDMYsNhhyd+SVJW<`h2 
zMQ>)NSjjkC)cyS4!b0A`YSYxYk24e;9&6Mt{PMY8!NJ(-#wwH33J#uMo<%IXBkyoX z<Nuq4Pz8sQBkrdre3WtcW$k;&$4ACt>LyL0HU~uqvv~gmKM@57uE~Cr^cO2QFdg00 z`Eiqi!}oamj@da14yPD(C1vKxIlORO_w3;gC5JL`$qwViiVjN4H@nz;l6P<jm=usQ zN6tZZ_vWI^`gPI{+L>RMp1m*U@MZm}6Nhx=9TwzXXju6|(P6&PA^#jEC5Op}ZrzUf zBI$7bS@DTdS2>4+j|%qOpDg3>ul~A8IiIqF$w|W{rt(q_+09qmbj{=(j$J)m@Zz$( z!=oo>-uehBIoMC$!LY_s(LrW;{xt3F3Jx|NlI?=6$_~?9o6gn0j#hHe-tvd{udKYo zjaRp?9GEBOuz8Eo>{oG04l3LCdMe#ka5yp}(Jc9(jKeM=o->EMlpQ`c_vW2aRd86j za@C%LpA;P0zDd{j*~&NwPRSSgwp8AMJ$>B-t%*twLT@)0h*c^%WL27p8{5h`Y&<*Z zWXCoo2Z{OeZOuIL4lm>u*Jo86Rdm=k@7QabQwk1e8fW^7c1b$;^(G3gWL0q3_)mP1 zke8yvy2%y4ww5S6?3CIf6fC0Pu%S2e=9W$chuP_hjrScT9kMQ_x32iF?C|5ev+vm_ z3J&*<DVr8*D>^Ly($~G~g_6VfqS%!yd8HiqR@Ekkgep03F1gLQxKP$Xe^px0+m>rG z4l5f(m7OFM9dd0~JXkzc-of;G4{uSCyhHR#oqsd*6&=d@<)$wSlXduVtTR36jF>|q z*QNL7y0Q*CBHzC~mmufxujbH!DO2Sg3K?6sGM<%oP+Nc9y);wF0aluWYs-lj^d<|4 zTQF)&J}55F)d8)YBxFI&hY8{wOd4_%56W?EfQlDrg2W$)b1-sDHk6Z}yg@VxSKSkZ zR;zGqE|9EX<P^|{7$vazqZB98<Rpa*rf7x9XA~f{)O@jW9%g7`h=GAY4`f?_90yZ| nA;=ZbCImx-!Nh~lCI=`~F=;Eqwf2iPL#xr>s_<%bKJ!EX$oq!o literal 9984 zcmeD5aB<`1lHy_j0S*oZ76t(j3y%Lo0fRb32+I8r;W02IKpBisx&unDV1h6h89<PM zK?1^M5QLhKt}Z0V)s=yPi2-IljD~7sFkpeO6d)9XfH;KF0HH?7VIBe=u8sj9FD5_} z+654Yfq@}I62btbGLRH{Iy)f+krxnz(4>-#3~US_b0Gu+10ysgGctljVZv~lnSle$ z7G_{zU}j)oV1@FTn3%vCIH3BO85uYrZUL!+ut1~`R2ZZ<cyCc60|P9lfMf?A2mATE zFfy<*fFpx}f#Cpwa6JHZ@(Fl(01-fsmkfw7!v!HI0|x~SGzJCKV8c)iD<>*gF{B2C z043>47jL>^h6WKx6~e27JzbF=9s=NeCcwb}%`pyO&*IK+2GDSUQef2#44C0EvT77V zJp-HMoW$Z{1@)x-B)y!}#G*XC)LgyX%!-oIqEx-~++00(UtizE<c!R`RCPOOHfG38 zEcmnX<0aoD2YXOa1q#i$(6ayO*$?d9LWLO(H1FFpGBOfndSYpDab{v(vAUgQ0i(8C z=Vg1oxie0@5k6{f5)s28$9!Tx%hG2Tzka;F-^508zjE!{{WD`92CMg<-hWPdq3BGB zYy01Rd>?i~@U;Ek*As6T^q#Ws3)YEWJN1(NqAOa<r+rzmzb-cKtkr>=`%CPt`6|2K z@2{Bfw6NyeHTw$XHqpOMq4p~_I!L#kIBg%XL2S=cvrF~|U3Mi@S}(NU?D<*g{no?# z3ybCacmBF+ztP53*5}7xdvo`W6C8WC*|(Mz8sB@lZT~g3&Z=GeHtyfw#<e71!NL7I zyWOP|WVhQ--DjC_L+-);eQmL|x89wx?>dqDEbGIi{rxwU>}0LC?>}GcGI2k{ru_@n 
zZ|kUM-L`*=l%&y2$1D2{PF^ic3cI{t_C<(}64TlJH5%@D$4wsE|9&YO#KJpye^4{K zLid43_K5*eOH>oC?r&Uv^!mZH$NPVt`Qdni>EV99yFy`>(jWGp7EV!g-*{+0o3e%0 zDeJrY{~G@l+Pdh9z32kZtEa!6x9{^adMI-Bq5abT0fkYLx9k<2ny=*ly1aj<!`fZD zYp2+=*C}YVI-T7gTcu=qnc=ei6c#D&{j=`bC-|<uvn=AM{Tat}@87c9_TSZFv9+6X zX8+;;>Q)!--?8_vtMqOUxv;-rj+T7Tt_$|NrMwaw-(A>$?cbNEo1gaYuaI7#HfzJh z{Vx>Tf|5gC?>Ftb5&zis#(qba-5<7=U$WnpdgO)Ss)hF4*179V=RVxeo~p&nZo6s! zkIbx)XofxZtsGwunXfrwf8vMWc2(_j_DQL4w_e(K!9MF;^pa&vN9_AIWNtem^1}X) zcYcZdms9&|z6D%SpK@mZ@i+YwPIg_e@ANn)9P#s@y*2C41ES@(_E)vf<9BsBX75{f z`r^OS+xK^Um(K|oy<;Ey>-H<*<sa-@mKy~p-`Hn=q<=+zapc$i0?U0LUT&Ocuka_q ziuLmC{qGx0#oj+TXTSa=zsTNcAMGc(=YOf4acaMPV8Lm7Ulj+pfOVSNl$0F4hfT;{ z(Wl@r_e;X|fHiUs7YuJ!`zFgfY)boo??$$Q!|yuVl*rQ(4i`VniBVxua+o3XzUv9E zlEY^G?JOb=iVn$d^-iT6ly?y7cp7@|ldOZi`1Cux>t!8&KlavWaaMF#V*K-W(F{e0 zbo*$v8$n7ARhipT{n!*8Zp-=4vx--AC_ho88h%>dq3!F(CwIkV9PUXLCo5l(cDQSA zKa=OWqQmEX0!|((iVl5IM)4gIiVh00n^YFcC^=YvILdMFm#o8eha&YVL1l-QGBL%C zEs732xmQyAViX(<{#)&R?Je(ct5Wf5XSRaFQ7P+}3tz}Om{c7Mm8+6-aNoazOR-hK z;eSrS40km}2kzY0V!@2k4ziZQCqBq2IMgmMbIvZ5cX-tGdV%R8B?q&(vpb&#%Qz@# zRTsVZA?5ISR%=kSx159Wze4@)1_g&N%E?j@+_DZXd*u|?|59|w`sUthBB|`KrZi~l zzvD^{5_eK=`973)C<stnD6v7@A=lyfu?G?=4uz$RN5h2`9agbhJ~GHxbYPH-Pq#au z<e)dtWx-T;MTgaU=N$hvNzUPZzP$PcCIyEK)<gY&1Y{hnrkk$X!=vO-bJpj-otC0Q zr`n`@Wrr0ULLOgx$$wqZ;j!iM`wV=t4qj)LUG!H|aOnES)^NaA*5S@$$=$neOFL{k zVb31st>EA)Upil<OWwicPpr(N93=<Cdw1mPX3IOAN{fHf93ktVz!5#G$6VTB@8el> zEVz^%x|TnGq%=*@!C}9|^c{6_4zJB;h?>4sba?r{_p!x9c?VVJe{8&>N)GzlKj(fg zlyk6ny~F({zoNs2vK7x1MU@;PyKH7Rhs!$bnK;XB_9hvJHvjt`{Lf__{z|=(d+}4= zA^A-E=U*}+4rL;5ZoX|)aQJ^mMw9umyhF?Ul$D|VN)Df<$5%gRlXJKzw~4{MUI9|W zf{OZ4JX}J6^mZ4q{TmBtLkqQ`j@jPo7*s(C>t}9&x(e1WDNuq63^I!J_>%y)n*|l@ z;q7J<ZU&S8&}hQlezu0jCzOJQ17>_yKxGDoB0W4v?MFC60{}`vy@?qPgQFioC_G^8 zc^J*iKyLc*h6V`g_#tMv47K#30nRTHOrWtJ1_n@n6Vl<~U`R{MDNbczU??dn1yNuH zxbsgmG<+bsp~(hI{}Wc?z*NwdCX_y4u7%Oe4Acsbm_ZpHknt@x1`BAgcd$VjI4BR8 zIBK$oiNa}EE%!kLE`*ga112g4V_>IY^-~0Fz!SSN23RTd05l+pLk%s22QFR34$5=| 
z3*+3>lEnD5%$(HtvecsD%=|p41X#ShD6=HBC>}f>0~ODUPlt_<fJI7*5;OBsix`sg z^U6|-(o^%2Q{ziAic*U+@^exc7#MOBE8^2KOG+~H(&LLuQVWW~x)>PBQ!~>uN{Sgu z3sMqGQsZ+|6Z61QP<~lrQD$OZaw=F1qy;Jt)5d`63UE;eQpPX=Ob$E&nLlF0Ie+v+ z7UIUyJ`5xrMjA<ac|>eKCk|TPKpW7|fE9wau0i$YV4lZ;)pG%`2@|OM(J81W8D7Ys z%b@aM#yH5Milg&j^G6EG=rX8$Sbj}VL=_(_JWA4)4&Hg3L}-wMR3Usj*wYnecxZqf z&j^|RWPtTEV6#)`eDeDrz0mNw05cM7FJ^cRbo-iK;StLX4giM16&?^bF))DJJ0!G& zU=0@Zh6;?2Xaqr;Frac5WEZ@B4&so32QEDus6d=MxZ)2s={HJ4CO=2RVQ7Z~0I8>G AEC2ui diff --git a/bob/bio/gmm/test/data/gmm_projected.hdf5 b/bob/bio/gmm/test/data/gmm_projected.hdf5 index c5b7f29da7cabbbf870f7ce48365e157e636e8b1..3866125f2b150cc3fb972a4317cef4343644f5e4 100644 GIT binary patch delta 2093 zcmca%@F8e|22+9NM6EtH4Gjhc28PK3>>nmx_%(S2lLV88)Wm~vlNFeSJUKv842)os zfq|Jp0?d|VWMmLvU;tx?97G)}0|!{1i3v<Ga4<-Kbuu$DO!gL0WG?qzHF-YsLZ%;5 z6A#KwPGI3+vXGp3P;PSv3lpOk#0ZFSlvpJoG1*5%0czJokS`cHCa-3*XHt-u{Dw_| zz1(+IA_K#Oi3k6BFoXO6u@~%McV~YeumeDzV_*P_fjj_G2cix23_uhE6PRRR(1y~J znK`uip=<~>gMTw4-+#u<0(>9v67LxWL52uG9RV_!X>uW}B-F_o3aCz=3iXb$pgWTR zSPtgTi5EB~YY6Rv>azgLA-p<u@&n;4CI<z0$mEOsW(rW4oFfVesD9DcObW_K+`@?$ zKGf?%{KBw%t6gH9-@*eTQ2zDgMC&3sbqD=9^N#dBRB)))Ua(#}K<0qRy1U!TOAHUZ znt!66%l)gpZo|DD*E~cVY(-@zWwB^GJSckGF#nA7frT4Qy~JiJ9(Z?o>u)p0Kl|IR z9N;UHV>s|CxT#lPSmeMi`%bfiVj}epPG9DKkCISu=yCrOrS?zNK_$|-YEH1~0p=qy zb0%`iJ21UF^i1Zhs)OE%)Tnnkq7EXrd5=62Q*#hK{dUC%S2KtBpyy6|Ed>r#Y}hq5 z@CD<68U-)e*{Tu;cKs09ow9@ZK>3d8|6bo$cVKmpF;+XL<G_@p_a*NYm&3hJcFw}e z^?&Rq?cTGzafP_UvSsh|<bVFRKg{fvb>$k{0TGrDKN;g>4p^r?c+qF6ePEfPs`j2P z@dMA#U6p4jlsT|^u5exiC#OSX;o+cDz8Vfod+xBtA5%YIoaL^!N1p3|tmNNE>MbG% zO0E1FFNT>Om=dmZZO0E02ev5>rd|1{=5XPF&Wl;~&ra>n-|8@}|EJi2ZI2vG`Q4Qr zxOY^0Uw*1^pf5^5kgZ$sz?S9TH*FXFzF&2{@OHi$g#){8?mn7e!|dP?kiA1cPS8Q4 z$gcl>+ROdfD;LN$=&Bt^>8)5=T_$=U$zgi&g-VqJd#=Bh&)uQmQ2cl8|5zS*hlcat z@@2QMI(Qn#-me!k{$_7^{B+=LHsu3<Ja(+;n#pls-ybE9+wV9Ii0ppwE%un?0S3+W zl9N6k*l!SKRihCk?{M+Y#$G-t-2>r_EPI}7eBD1|=cP*y>%Q9GJIE%W!>aC(v!L&n zbHHQ!qe*_e68AV9Vm|FN(mBuJ@a7zQ%eUuD2hR1JQ0+M{=1{(DOa1dFb|3bq?whcl 
ztzE=n!Q)T!HJq>5%XD0rFq!Fwz5EWwbwTr44-^D#c@uwE(qYxL;AY);F$aImw~lUM zA_wyJ*u6YaDR4ks<;DYxcT5N3e`-n|sZ%(xHazGd&mx5badR^&TtXQfM2^dBK45#z zUP&;`VUO<j{Y`t$uGpr*c0jJG{(Ir^zhMrxDZTe@x<xun%`HeOD+qO1y!7G`6Xg(x zSNtWj3~C}AQeE1-Ejtq&^m^wVj659T;CgSyf2)=thky-=mpA_MbohTXlI@C7q{Bj~ z%+klZyc`&dTUVZq_jB0$wRC~ep=bxA7vEl1)kZj|GhFgeJ|5_x{OwJj|EXYyFZJtB za+R3`IczV_yLLc4%HcWVy!C1ti4Occ{jN)+A{;jCRo-CuC&b~h!*ikJbKwq)&L67b zyqxH;YhPW;7Tb6S*^tKU#v{=V8(R7t1aE~oOkwQuTC&*R;cL({-gdtPhc*B2SxtT( z;qZhjbc44~l!IvfkCJESeI4%e{jgcC9Oh78T<k1i5gOp2@T=_v>(mH`U)2pN$5^5r za$bKlw=?l~IIUyMJSQR4VN%D&!o`=p9M;d14UHBKa*%$yGJW5oIEM>YI@8vK#W~0~ z8il|89qVBHP`k@^MwG+m2z{+Y<`4&8`>wrJQNa#%#oM3wvIjUEk#Ak6yUx?$-yc@i z*z|yUhre2%IlSWo9WobmDjI5fIhYHFZr5fHcc`{0`w*rQ;Bc9fPn1nC)Zx#yg_>(q z>>M1EPx+os4{|tP`G{rm*8qn*xe6<T0|FfWWnS{#lojkCB;2*_iB6z{*_ml3#V33m zcs`lBy6^@$Xl|}M=DFC<VV?bCm;RYS4in@~9Nw;7ALvjQpe6fdeSkyA43kpdH$e_- z@8v~rYYlKX>yz+`?^3wK)$U_4>_XlS<uh!m6Q%|@+$pTTxzQ=uq2*KN#aE(!4!vgQ zKD=j(blB%2C^45Yz+sZh!cERMd>o==uk4YxbaTiJ+rLcH$<5(Ln2ILXO&14Oe;2I{ z9YGF9)|?It4!-5<P&4bK$Hf*8hc8zJ-YdihIjoe__B+_&?QlIY>X@6LkAvaP>=}LE zA{+!BTYV_h40A}{wLl`LA<Q8xtc%6zuBXGnf{!}Yx`7TqAAeX9axlOFTq{kyU^!Vp nT!4u~22>3@h;uM$NPt)w;s#6-vJ(%=aWz0odJVaW2d4o5PskK2 delta 1951 zcmewmbi-hR22+H@M6Et10r8Ey*qE3UgeD%0o2<YrG+9c9PnMAZ0$?;V0|x_$EzH2c zz&v>`n=B^-2Sf$)<lC&0%)xt$CeLSH$aFzy;z7B|2`n5;0)i6{%5CmoVPYg`pf=b* zW~hPMNCr-XIf$8&fnoA&HhU&Lj>&h}6j*}y79~#H_kr0^&tS3#hju*^#4!xIP`X8U zl~|?!!UG~ue&Pu^#$Cr095$Fg{O4+^?(ir-r!8@T_<?__eMT9Uh6k!%nm4O`WIC{W z`+_xhb3fRp&187h^Gwwt@Rx|PVa-qb895K8`*SKE(AZfewlS6Oz_jl-CH+qbA2?-e z^l5{<ME!wPmsws_?hteM`Sx~P{5MgDD-+r+W9ErCL>ewW5HMNeK)qyr7UyI+hf`}W zGDIFwc2KITy*RN%*x|5y%bL1k35S1mjuUKI^&Mh1_)U1tqjcc5pKAFu?_c%?^84kh zO1KU<t@*a>e)QY@e|Ao8R}9g0h>nPxvG%f}gU_^2Uq2_*vpYy`>*zSg!|$N(zsy7O zyqJT|?6oe_jD!vZ{P^)!TVl=rNq(l1m-mYt5WHzL@!KAm1KUi#tlDXM-Tu{uXvam< zqz=eP2S4d(`n&&Y@N?_;(^VZzIyZ^V{I7T*hbR5&hwvBs*H->Wd;VGJfaD6L$<pH5 z2ew#=yxx6Sz@cvaC!U1*ZK4i2a@xjgnZMaDdj07p>q6lJ|2FBJ3gT6Ekeeo&u)|jQ 
zz{9NdlB$eK2d*S9czK?a$w85)P+{LGfdf2GMK-^_DeN#O(Ot0StiXW-iqn>aN^Rc1 zWNXC^D-opw(-{7#p08p)Fw5KiV1KdFfgRE@J&Nb$96bJTctkB$aZofin^pf`TG4^^ zS8ivF1c$@#k9RDlpH@2HKk=F3w=@6upUwy}Q~JkpV2_b%Zj8Z?{lQz)KCNn%a_}(v zp`$oc(;<B&)5BAFMhEIAX8)OPDsUi!^M2j2B`5Z$AK+^LHbK$h<)X7&e#S{Sc<6U= zPfq7?cyTjKf3E*~d%N$qj;fk69(eXbqF&mxLBruoQI6iTR=xwz1lP=6f05H++u^=# zUkW7-Xy+X7c~>ibz-!CbEWQxV1L7r_df^Y%9M%Z(Y<zr5+F_l6qtd0v@&_LEZ+*I= zPU^rmdHGkGoBrFgKMp*yl*`}%!{lk#W%y+e{7r6L_f=cLVM|M9vcSd9`_<>1oAJ?| z>A;$Q^(}vnII<q--*N7P4^OZ|<*UamO$8wikBUB2rET?bknv;uY8T?;z*F~V&bPaP z4q2X$;tV*#9psNZ`*-NRn?sV#rzsKcUJkc<_|leZxjIC=3W?Ry40O2cX2Lbs%+(>u zQd2~Ky{kih$iba{t-%hhkz5T~i-H`Q6uor&E(AFIuQzY`VImahaPEyvdq#<;Ll@&O z0af=92dCBk`RiUsIm}5@zZQ@g<e<&nkkHNH<KUpQa@~tb{tjEs_?}lyk8wy)HcCD< zHOyh&6{c{5SwRl2^FRNX{V&kLB{=!?PHi`bqqVB#kyeoo|Gg(YI{Y%o;Yxs~u4Y=0 z!|$V&duLcVJJk2eJ0%zG2yh5m=%1JO$-|*A{B++nkpPFlYfm<Y$Ok(l8@X}5RCjmy zl6QUs!<RsZ#<vj`BAG4@<>F^|+}Y^uu>Ong!zJ@V9QcjIB&Jk`JMcbhiMlEi=D@A( z8hB$uh=ZRY!!h2eehw_l+ZmUb_&CU!K1e<BBiLan=bH_AYXj>Yj12#2D%6HMEYx%f zPKXG1Fh~mPyR<3T!GGoc@7p>O9M+^SG&)}%>G0sB?x`t-Q4U}Do-igdxj6i@JO8bu zFv{Wh;~J}N+k+j}b)EaXXhn!aJa@avwB^wbs~46jd}R-JDD7O9J$I|0gWZOCx8sh6 zIQUMkTJ`L7h{Nfv;au8T^$`xL<y$s~O^9&d{x-44V_}#>PJl$I<&#K<j6O#3rCkvY z`CG65S@s~#!T#a;^=}Qr9KMV1?%%jS(jhEdN>N=d!XYidW5ugu!43)QcX?ehNpPt6 z5nGV+Khz;PCML!uD9phyTXUCLdVs^^S)yw_djcF9=0v~dp6KK7;p6}0`_UZH4hl7l za<eamIh<X&%X+d$fW!02C60SehdLzOI_SLOXMn>~=O@tu&Or{ht`zf!)FwJSY+Yy9 z#TM((8tNz75+CP~@KD>D?OTw;)eQ<Iq0d4c&iQXotYr*y*gTPgi3y_~Fhi~fI#8s& zCEc0qz|!oXYJ!2GV&cKp$pKQkKot?ZB5?z&fL0_h10opxCRZ~nBJ|aQP1cgl;EID* zPZyXb-(i)2s>qmlFc+>uj~i0iV5%@+VVkU{!Nv4Ic;Z1hE)R$@1_l98UD=>v!1O@` P#JZrt!Nedo@!&K7-V_k% diff --git a/bob/bio/gmm/test/data/gmm_projector.hdf5 b/bob/bio/gmm/test/data/gmm_projector.hdf5 index ebeb5462e0ce7a3d28f03fd77ef20cfb1eee1520..a25006d8ede0f7e5055b75c1a4c828d09f5e7c3b 100644 GIT binary patch delta 2493 zcmZqh`;jt1gQ>!3qE;W9h6V!z1H)ti_74*;{F*Ev$T4{XlK@kL*2E||t_Rv6IR*v` z$%zNeCMz%td2)aS7{Me112cmJ1Bflj$jBhTzyQV|J_7?p9V-I|Se}UqOfhgUNPu-R 
zGcru}7Exp__gyu4KJ!AR7m^bX%1usSk&r^z2v*3D0ihUvNP$#rV7cW5F&SbDB`y$< znCv5>0CvKD`NfPJlUK9ZGbu<+e#54~Uhcapk%3{t#Djl5m_gox*a`N7yR*L!*m976 z85qD~AUA^4foMZL0}#c)1ST06bfNU*1P&zwUtizE<c!R`RCPNBklP`0Fr~RU5I&4- z;NSdF>O12m=@&Th#0%<zAR7drE(96HG`Wyfl4`-Ep#Tr2dV!ge5m3ijK<NlcC}Z+M zC3{ANi3j~b$%K((@?$}FCJO~nXrcwN16U9d#8W3f5YA%qKv?51G?|fOa<zmekDH@U zuqy)tLr74l>*Ng*N=y+7lXFBNafU}#B(kc;i5EW9gMtMXM4*WJv-0C5-y{cna6~dN zFuW-h+h3gd&|XF4a-)RqU3*4GMy^li<&F29-oLyq(DP&9A$#p7t2MrIoU)gV=;Xh{ z`@sI#L)ErNGp_Hym?}3(*8l1L$KP+*#``?k-)$SPP5#;{`!?4NO3`vR?R5iJT&<VB zR<>U{tvQd!ZS(%GF4~!1pHJ^E4Qrhw@@3clfBwZ&V)h=hFT2IZC@lBfUg@FX<|oUa z+8^A0>80Gwo%TsaN2TYNp0J;;8M!~^&<=aC_Y*l)<{hwi+RZL@o9(K-t%Xkf8KblN zndXX^zF551ew)j~EqsMX_8%$T$9nU?Zu>g(x%Fks58bjaNp!p;A9Zd2@>zWX`=?E{ zzjD|9xc>j+_Wb+SeOB|hzF%ye!g221ANR+qt{3lmzifZ`jjJh9H&59MXWdx*@Zz!k z`*T*zV~IGtf1+^7)C)!T?cKC*DbC8jWPe-6g(<Idw|(r^eZDf<XZP3ceI6Nkb+3I( z(Q-aRuCw-6yX$90eihqi&$4=Yv(?V0`<E`sGe5fd^8Q6m>5tf7%&}jz*5PYk?n8UE z4S}UkQts|g_jVAUH0AF8<3@7A8u1tQ^9bD6UhQ_({+yoJ!)mF!_9x}mDb47)W^c|C zRI0dW_kK&4B}e-nKG-i(wEK|L(uMYu_3HEf3!dG-r|*<?yTEz-x%K^t8Q-UD*uVX& zDQCu&Bm1LgOg+>0?bH4*R}cT3vhU^o;5|FTLyXSvcah6jU0ii}f0h1?6)z1g+XuNn z<&-WzxW8QE;UjO&d;4qRe#*}`zh=Ma$Ej`oFBaNAKQ`@Y;J(fFQihD%Z|Wb~zrR|w zK|S!!ej&*M&G!~p?JJLm7Oz!*XkYIyotPewddc49V>F-B?sN9vn%2uMQTS`$uhP3I zpZUOk>77p^-kn@zU*Yp1ByqxH`{~+0uCc#9Zm*}H#Obj7`TjfCBPI&HzP-P|Gry$e z@d5koKmVzn*S&AQ{d}QO|Kr{EE-xzhc?>u0FX`O8g?G;``x~JLHRecMwV$vupY4<R z&3gMFsY@S3Gf(dSX|p>_qfOr7*7pDEexZsE^5;T@z1tNWc--RlJnffvSokzlCTW9$ z1K+|6Hx-{KIQ-19TgUlL&Oum6e(Ls5iVg-f?-|O?6&>skK07bHRl(u?jr7u-XR;0w zxt{N4T~cr;?p&2DTqNhPX}Mmy?qPX{`wjb!bWW?6a}b!S_N->BoI~6fz60*R6dbPq z4&!I5P;j_(^Ov7ior1#&o6cO}i*gRTH5Y76^-yrwn7#MQj%NxEsxIt$flCw}4*n63 z-F{rb!O;Ee3CVN?hbJBqMiwSY4qpRjHDCL#=pdB1u*J<x!Qsu7$It!y6dk0h4VhP4 zDmomh3U2bMXOeNa<a>I}vNlDBTR|I@>MRr-mN^A+#;sCtklpBD!+1==q3HAHy`k(1 z4#n9n2LqofI0)@|S9eKL$w5kI(VV?kWE^g@u3sw5Bj-@h5Uk6ZrQpyYHS@3SDMbe> zlWh))x8xmO?X+9hvsuBxrTEmnzCuL@F*aisxg+up*FsibsMlJd=<w}eO<Mjo1&8Y5 
z^MUR66dYKdC|!8}M!`WpDbdl=K;GfU(erJ!M`Ro}g-;Xxb5+sd|F<i~9re-<s!l>) zvL5mdKT{Ng{d46V_R6m)$mdaV*icoh?kKC^VAPd9Gjh9}!@K_J`RThP9jr}Lm22P0 zImo#B-oCS0-oYbFq<*RV21SQEpE=SFpHy(Dar?f(@PwR0aIS8EN~WU2&6CVsS~c<x z4m_r27S9zN7&ZDQ|BID#u(5D-TDDH!A?)~T=eeSa4tgIQm!8@q=V0(%;=0mrS%=To zLRyuqiVjsdYkp7dRC2f~|EYO-m88Q|%LQ6ZlNB5)^|rj9w_nkLX-*oO@n0qd2mM7~ z3$FMoIP6Pz`D;5-!NFQ!!OTa#3J#Upe?P<tDmoZ!^<TF{PR4;FDBO|bp@PGaTL&kH zImtVGVb3Tm$X9R>PF%KWS*oIgUhm7mkD-bVyoR%~z8#czfYd-6Y{0eU#0z?p1;i~F zH6|Yvm*?t$$T3LBf|?H##5tHW<R%`J<JtffFVF;uKM?0&<d|$ICqH?EXc8lix+e;) zR^ixOAX&l4DWDHAN?`LxDNd%zNeUTE(F&8#C_rkd`C{cf%+SUV0|SE|$hH7E4yFu4 mkSm}~2!;rQi3gue4p69K(pH9R?H6r^R-?aF;nnDT=7|7plWf!g literal 9984 zcmeD5aB<`1lHy_j0S*oZ76t(j3y%Lo0fRb32+I8r;W02IKpBisx&unDV1h6h89<PM zK?1^M5QLhKt}Z0V)s=yPi2-IljD~7sFkpeO6d)9XfH;KF0HH?7VIBe=u8sj9FD5_} z+654Yfq@}I62btbGLRH{Iy)f+krxnz(4>-#3~US_b0Gu+10ysgGctljVZv~lnSle$ z7G_{zU}j)oV1@FTn3%vCIH3BO85uYrZUL!+ut1~`R2ZZ<cyCc60|P9lfMf?A2mATE zFfy<*fFpx}f#Cpwa6JHZ@(Fl(01-fsmkfw7!v!HI0|x~SGzJCKV8c)iD<>*gF{B2C z043>47jL>^h6WKx6~e27JzbF=9s=NeCcwb}%`pyO&*IK+2GDSUQef2#44C0EvT77V zJp-HMoW$Z{1@)x-B)y!}#G*XC)LgyX%!-oIqEx-~++00(UtizE<c!R`RCPOOHfG4p z`LpulCEp|mdr(OQ3d|2q%hDHE-nZ|(r*lR>`GGwnBO_6UCzci$XC~$qtJ`@_m>bV? zYomQv>(=)Qw;%71;m$rJ*z;h&LXYb~z0Sw>ro{)es&+ivFV@B}%UNNkJ$G%{-Pq@k z_Se7fIIR<XWq)a`knFZaNA~mjOr2-cdu0E%^Y{I4>^`~w?z#yP>Iu*Gr`$RrAV1^D z{@H)G^E|0quwVZaU&gidi}tS)z0dzDaix9w=j9j9O*>+LIbru^-lm)OGjt|&uVQ^< zZ?P=HxaIDF{Q@yhUp`#+*?yb1(|xJni}s)QUw*M){gC~Xhay{q{m$;sdHv@9g~;>! z&pG~>QXO(@f6VE`w`Q-N>|ZUD#~-@nhJB@CPonUyJNv&YIJftA?zfj#|5)yD@6i4~ zW<uZQtU15mp*Q+d){Zv&;BeCm46hH{pMH9mWy9TD``750ZVYj#u+Lt9wr)f5)%}`! 
zvwNNOSL|oBs1a#3zPNvTXP^FsvYYmEE(&z-6T7&-Ywd@!LvrWsZ{0rAbM5Qi{SLpU z=@z=Y*?&`@F-zjZE&H@x-$k2+PuN?hx3!d}UA4cG^*hg={j&YBzdR?7uDiHjHf7i9 zswt1`(_WW0yo|oNU#IGZ-L6+B?JauRpWZ!p!M=Ojg_X~4?X>?}K5gQ>{m1OrZkX__ z+j#r_wCxv8RGfQazr#c*SSaO=y_;u3bd}g<dv4eJAKi2=?BA~1z3ukaW&8IX4BUSD z@`e4jx!uaEm)_Zb)vxOM@#v%bZ5Ezc>aKgp-Y>y6xAo@d{S&`DTeqa@#Qw5|xju8M zC)*cxp0}IzbFcj@DR(>loHO?BmD&9D{&(!J{r2<qtAA*JRdKfJh5rZapJ$}s7PWt4 z|L}+7s(%wN?QdDEx@P{GNA}NO9-m&aVzd3%!0-%>==1gs#j|rR6`Zn9&M`R5_Wh>) z_7(fih|3+?|3zlC*5-(F_Fq;m^|IY|(SDY1|NEUw@7QbeS@pbfJhH!*>)*GxD{kAf zA7Pj$V9~Q*!l&?hqWzWq@q0h*aM<wLzVX_T=l|Ot?w31~R`vX=oP)mKwD-!6iVl29 zDoQ2w3Jz+5XMK)3D>wwco6EY~T;9Q%Te_V|OTnT0tDLT(n2f_&T?6i|VM-2R57vEk z6IXCJ{bgtW?G^<G=^WX)oDT91^CAwrACXmX=q?oHa_f?FuwJr5@j|YG!_mUWHLL!} zIrPli^C*N<(Lvhk#n#;U3J&w4&P<;ttmv?8iLP?|Wd(;ai$^_A%;g<scnVik2T3`^ z-7;)6T&Cb~=eH+UF_)r4SkTmMq3!YxXEhkD9vqZ+m~d!=Rr@L>hti-6+$KJX4#yJh z!@u8^b=YK|r{Xd}(c$&ZGtYwNDLVX?f9@2uT*<*FsU`5ezM{k4nMwLLE-5&;_c30M zu2yh(e(J=5uBQqPak4(=^jQ@gWLC%@Klw<(L3u^VonIg29STj?@~+z~>+pEV@qjP- zat=P$7b@?U$vfCagip;YP;i)Va=`|swF(X`D_hPS?pAR4fBzh#khOxt3XjVXi`FYR zxK*FBZq`tANWT4e=ZgjfhdB&V->(WQIQ;PM+dKKVg2Q44*A<=_iVo=_A2@CqD>!ib zEo)%aQFQQm#OWaPQr=<FgQ&xcm&-deM@KW1{gZXD@;B^}U{Q3C-4q<Mbg`_1io*ZA zh`+K97hbPBusmDAVUomC(UtGz9olF2tvukM;NTS_<F~L`!C}$?yXS{q$vgb+dGPjq zik!nrn?(^j@5(y7VySL<YN6mzY_MBBw@1<8-?|8y1<nc%vaya+SsWA`rg{l&_<B^q z;m~|OJH|C~4i$N<MYB5<94sYIOv>CM=g{<gxlzA~l7n3npJl&|f<uvo_f(A+3J(8o zonD&sTfxEpuXw-e9R&v){$<B><K!KlpSJsJ|4q^1iqk6z^N9)$okCmN(xsIhp0|X& zER2?RC^jv$?($P`_`SYi#iLL14%^vp#dn-jbWrvYwAfRw=)k!%_sQW4vJPsSG|N=O z6(AKXsHh*s!zKhsZ+9X3H=wov?)H`iw4nu3g$ORp_SWF)-@y8r8=y{x^-BtrpaO%8 zB0c^j!0l#11$%hA*@T<H<Uh1Og*!g2pz#T%py7ZSpA}G<fuTqb4^sOPPS5~=Qc!PV zhQr|KM-U1RSnh<;%nanF4=-qdppGA6hRaY(9~$8NBEbY2`(a=J^*13s9u9`I#GK+( z1_p+bqEZkAR)9PIL_xy`q8plQu=GD+B@RplZD~U31Lj&7&CEcp@Q5Ci;Q<-nVq>s? 
z273n^l!1ftfQh3fdzdJkhShQ(MBqYL88cv_VlW1F8dg6=NMM(TiNZ>u2cQ8-m>?bK zflF60gEC#gf;>01Br!fMGbc5^EVZaOGd~Y10TwSW$}CAOiU*I!K*jUo(_sT7V3Crd z#LT?ZB8KGryt34y^whlM)cBH&qSWGy{G1d928P_kiukn5l9J54^!VbE)PiELE(V73 z)XemZl46F^f|SIP)cD-g#5}MRlwX!ul$n^9oC+2LX@QEvv@xK%0$kLAlrc;IlLJpc z=8qV0&L91dg}8CF69Wl{kw%hU9wGWUpr$eIdOa3e-hfmgLRARbx(3yogLxhYR?h{% zCQJ}sKw&{W$?!r3MHqtvGsZz4LkdLzHh-j`j3SJ|f#uf(MGUDyAwWsG(!o2AlK>5J zkSc^%2Yb503=a*k;~62-pA4{m25fc;olk!MqX!yZ7hpz$?ZphQfo@;ZD?DPj!2!T9 zxWWVCCI$v@dxwN}5Ujz1-cW(@5se^769!c7g6x8~&p{kA@W7>K0~LsK2V?wguz|R3 OlpYNSPzVg&bN~SDj5CD* diff --git a/bob/bio/gmm/test/test_algorithms.py b/bob/bio/gmm/test/test_algorithms.py index 1ee7f9b..ad4f79c 100644 --- a/bob/bio/gmm/test/test_algorithms.py +++ b/bob/bio/gmm/test/test_algorithms.py @@ -25,6 +25,10 @@ import sys import numpy import pkg_resources +from bob.bio.gmm.algorithm import GMM + +from bob.learn.em.mixture import GMMMachine + import bob.bio.gmm import bob.io.base import bob.io.base.test_utils @@ -75,9 +79,9 @@ def _compare_complex( def test_gmm(): temp_file = bob.io.base.test_utils.temporary_filename() gmm1 = bob.bio.base.load_resource( - "gmm", "bioalgorithm", preferred_package="bob.bio.gmm" + "gmm", "algorithm", preferred_package="bob.bio.gmm" ) - assert isinstance(gmm1, bob.bio.gmm.bioalgorithm.GMM) + assert isinstance(gmm1, GMM) assert isinstance( gmm1, bob.bio.base.pipelines.vanilla_biometrics.abstract_classes.BioAlgorithm ) @@ -86,7 +90,7 @@ def test_gmm(): gmm1.number_of_gaussians = 2 # create smaller GMM object - gmm2 = bob.bio.gmm.bioalgorithm.GMM( + gmm2 = GMM( number_of_gaussians=2, kmeans_training_iterations=1, ubm_training_iterations=1, @@ -101,7 +105,7 @@ def test_gmm(): ) try: # train the projector - gmm2.train_projector(train_data, temp_file) + gmm2.fit(train_data).ubm.save(temp_file) assert os.path.exists(temp_file) @@ -109,8 +113,8 @@ def test_gmm(): shutil.copy(temp_file, reference_file) # check projection matrix - gmm1.load_projector(reference_file) - 
gmm2.load_projector(temp_file) + gmm1.ubm = GMMMachine.from_hdf5(reference_file) + gmm2.ubm = GMMMachine.from_hdf5(temp_file) assert gmm1.ubm.is_similar_to(gmm2.ubm) finally: @@ -140,575 +144,576 @@ def test_gmm(): ) # compare model with probe - probe = gmm1.read_feature( - pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_projected.hdf5") - ) - reference_score = -0.01992773 - assert ( - abs(gmm1.score(model, probe) - reference_score) < 1e-5 - ), "The scores differ: %3.8f, %3.8f" % (gmm1.score(model, probe), reference_score) - assert ( - abs(gmm1.score_for_multiple_probes(model, [probe, probe]) - reference_score) - < 1e-5 - ) - - -def test_gmm_regular(): - - temp_file = bob.io.base.test_utils.temporary_filename() - gmm1 = bob.bio.base.load_resource( - "gmm-regular", "algorithm", preferred_package="bob.bio.gmm" - ) - assert isinstance(gmm1, bob.bio.gmm.bioalgorithm.GMMRegular) - assert isinstance(gmm1, bob.bio.gmm.bioalgorithm.GMM) - assert isinstance( - gmm1, bob.bio.base.pipelines.vanilla_biometrics.abstract_classes.BioAlgorithm - ) - assert not gmm1.performs_projection - assert not gmm1.requires_projector_training - assert not gmm1.use_projected_features_for_enrollment - assert gmm1.requires_enroller_training - - # create smaller GMM object - gmm2 = bob.bio.gmm.bioalgorithm.GMMRegular( - number_of_gaussians=2, - kmeans_training_iterations=1, - gmm_training_iterations=1, - INIT_SEED=seed_value, - ) - - train_data = utils.random_training_set( - (100, 45), count=5, minimum=-5.0, maximum=5.0 - ) - reference_file = pkg_resources.resource_filename( - "bob.bio.gmm.test", "data/gmm_projector.hdf5" - ) - try: - # train the enroler - gmm2.train_enroller([train_data], temp_file) - - assert os.path.exists(temp_file) - - if regenerate_refs: - shutil.copy(temp_file, reference_file) - - # check projection matrix - gmm1.load_enroller(reference_file) - gmm2.load_enroller(temp_file) - - assert gmm1.ubm.is_similar_to(gmm2.ubm) - finally: - if os.path.exists(temp_file): 
- os.remove(temp_file) - - # enroll model from random features - enroll = utils.random_training_set((20, 45), 5, -5.0, 5.0, seed=21) - model = gmm1.enroll(enroll) - assert isinstance(model, bob.learn.em.mixture.GMMMachine) - _compare( - model, - pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_model.hdf5"), - gmm1.write_model, - gmm1.read_model, - ) - - # generate random probe feature - probe = utils.random_array((20, 45), -5.0, 5.0, seed=84) - - # compare model with probe - reference_score = -0.40840148 - assert ( - abs(gmm1.score(model, probe) - reference_score) < 1e-5 - ), "The scores differ: %3.8f, %3.8f" % (gmm1.score(model, probe), reference_score) - # TODO: not implemented - # assert abs(gmm1.score_for_multiple_probes(model, [probe, probe]) - reference_score) < 1e-5 - - -def test_isv(): - temp_file = bob.io.base.test_utils.temporary_filename() - isv1 = bob.bio.base.load_resource( - "isv", "algorithm", preferred_package="bob.bio.gmm" - ) - assert isinstance(isv1, bob.bio.gmm.algorithm.ISV) - assert isinstance(isv1, bob.bio.gmm.algorithm.GMM) - assert isinstance(isv1, bob.bio.base.algorithm.Algorithm) - assert isv1.performs_projection - assert isv1.requires_projector_training - assert isv1.use_projected_features_for_enrollment - assert isv1.split_training_features_by_client - assert not isv1.requires_enroller_training - - # create smaller GMM object - isv2 = bob.bio.gmm.algorithm.ISV( - number_of_gaussians=2, - subspace_dimension_of_u=10, - kmeans_training_iterations=1, - gmm_training_iterations=1, - isv_training_iterations=1, - INIT_SEED=seed_value, - ) - - train_data = utils.random_training_set_by_id( - (100, 45), count=5, minimum=-5.0, maximum=5.0 - ) - reference_file = pkg_resources.resource_filename( - "bob.bio.gmm.test", "data/isv_projector.hdf5" - ) - try: - # train the projector - isv2.train_projector(train_data, temp_file) - - assert os.path.exists(temp_file) - - if regenerate_refs: - shutil.copy(temp_file, reference_file) - - # check 
projection matrix - isv1.load_projector(reference_file) - isv2.load_projector(temp_file) - - assert isv1.ubm.is_similar_to(isv2.ubm) - assert isv1.isvbase.is_similar_to(isv2.isvbase) - finally: - if os.path.exists(temp_file): - os.remove(temp_file) - - # generate and project random feature - feature = utils.random_array((20, 45), -5.0, 5.0, seed=84) - projected = isv1.project(feature) - assert isinstance(projected, (list, tuple)) - assert len(projected) == 2 - assert isinstance(projected[0], bob.learn.em.GMMStats) - assert isinstance(projected[1], numpy.ndarray) - _compare_complex( - projected, - pkg_resources.resource_filename("bob.bio.gmm.test", "data/isv_projected.hdf5"), - isv1.write_feature, - isv1.read_feature, - ) - - # enroll model from random features - random_features = utils.random_training_set( - (20, 45), count=5, minimum=-5.0, maximum=5.0 - ) - enroll_features = [isv1.project(feature) for feature in random_features] - model = isv1.enroll(enroll_features) - assert isinstance(model, bob.learn.em.ISVMachine) - _compare( - model, - pkg_resources.resource_filename("bob.bio.gmm.test", "data/isv_model.hdf5"), - isv1.write_model, - isv1.read_model, - ) - - # compare model with probe - probe = isv1.read_feature( - pkg_resources.resource_filename("bob.bio.gmm.test", "data/isv_projected.hdf5") - ) - reference_score = 0.02136784 - assert ( - abs(isv1.score(model, probe) - reference_score) < 1e-5 - ), "The scores differ: %3.8f, %3.8f" % (isv1.score(model, probe), reference_score) - # assert abs(isv1.score_for_multiple_probes(model, [probe]*4) - reference_score) < 1e-5, isv1.score_for_multiple_probes(model, [probe, probe]) - # TODO: Why is the score not identical for multiple copies of the same probe? 
- assert ( - abs(isv1.score_for_multiple_probes(model, [probe, probe]) - reference_score) - < 1e-4 - ), isv1.score_for_multiple_probes(model, [probe, probe]) - - -def test_jfa(): - temp_file = bob.io.base.test_utils.temporary_filename() - jfa1 = bob.bio.base.load_resource( - "jfa", "algorithm", preferred_package="bob.bio.gmm" - ) - assert isinstance(jfa1, bob.bio.gmm.algorithm.JFA) - assert isinstance(jfa1, bob.bio.gmm.algorithm.GMM) - assert isinstance(jfa1, bob.bio.base.algorithm.Algorithm) - assert jfa1.performs_projection - assert jfa1.requires_projector_training - assert jfa1.use_projected_features_for_enrollment - assert not jfa1.split_training_features_by_client - assert jfa1.requires_enroller_training - - # create smaller JFA object - jfa2 = bob.bio.gmm.algorithm.JFA( - number_of_gaussians=2, - subspace_dimension_of_u=2, - subspace_dimension_of_v=2, - kmeans_training_iterations=1, - gmm_training_iterations=1, - jfa_training_iterations=1, - INIT_SEED=seed_value, - ) - - train_data = utils.random_training_set( - (100, 45), count=5, minimum=-5.0, maximum=5.0 - ) - # reference is the same as for GMM projection - reference_file = pkg_resources.resource_filename( - "bob.bio.gmm.test", "data/gmm_projector.hdf5" - ) - try: - # train the projector - jfa2.train_projector(train_data, temp_file) - - assert os.path.exists(temp_file) - - if regenerate_refs: - shutil.copy(temp_file, reference_file) - - # check projection matrix - jfa1.load_projector(reference_file) - jfa2.load_projector(temp_file) - - assert jfa1.ubm.is_similar_to(jfa2.ubm) - finally: - if os.path.exists(temp_file): - os.remove(temp_file) - - # generate and project random feature - feature = utils.random_array((20, 45), -5.0, 5.0, seed=84) - projected = jfa1.project(feature) - assert isinstance(projected, bob.learn.em.GMMStats) - _compare( - projected, - pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_projected.hdf5"), - jfa1.write_feature, - jfa1.read_feature, - ) - - # enroll model from 
random features - random_features = utils.random_training_set_by_id( - (20, 45), count=5, minimum=-5.0, maximum=5.0 - ) - train_data = [ - [jfa1.project(feature) for feature in client_features] - for client_features in random_features - ] - reference_file = pkg_resources.resource_filename( - "bob.bio.gmm.test", "data/jfa_enroller.hdf5" - ) - try: - # train the projector - jfa2.train_enroller(train_data, temp_file) - - assert os.path.exists(temp_file) - - if regenerate_refs: - shutil.copy(temp_file, reference_file) - - # check projection matrix - jfa1.load_enroller(reference_file) - jfa2.load_enroller(temp_file) - - assert jfa1.jfa_base.is_similar_to(jfa2.jfa_base) - finally: - if os.path.exists(temp_file): - os.remove(temp_file) - - # enroll model from random features - random_features = utils.random_training_set( - (20, 45), count=5, minimum=-5.0, maximum=5.0 - ) - enroll_features = [jfa1.project(feature) for feature in random_features] - model = jfa1.enroll(enroll_features) - assert isinstance(model, bob.learn.em.JFAMachine) - _compare( - model, - pkg_resources.resource_filename("bob.bio.gmm.test", "data/jfa_model.hdf5"), - jfa1.write_model, - jfa1.read_model, - ) - - # compare model with probe - probe = jfa1.read_feature( - pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_projected.hdf5") - ) - reference_score = 0.02225812 - assert ( - abs(jfa1.score(model, probe) - reference_score) < 1e-5 - ), "The scores differ: %3.8f, %3.8f" % (jfa1.score(model, probe), reference_score) - # TODO: implement that - # assert abs(jfa1.score_for_multiple_probes(model, [probe, probe]) - reference_score) < 1e-5, jfa1.score_for_multiple_probes(model, [probe, probe]) - - -def test_ivector_cosine(): - temp_file = bob.io.base.test_utils.temporary_filename() - ivec1 = bob.bio.base.load_resource( - "ivector-cosine", "algorithm", preferred_package="bob.bio.gmm" - ) - assert isinstance(ivec1, bob.bio.gmm.algorithm.IVector) - assert isinstance(ivec1, bob.bio.gmm.algorithm.GMM) - 
assert isinstance(ivec1, bob.bio.base.algorithm.Algorithm) - assert ivec1.performs_projection - assert ivec1.requires_projector_training - assert ivec1.use_projected_features_for_enrollment - assert ivec1.split_training_features_by_client - assert not ivec1.requires_enroller_training - - # create smaller IVector object - ivec2 = bob.bio.gmm.algorithm.IVector( - number_of_gaussians=2, - subspace_dimension_of_t=2, - kmeans_training_iterations=1, - tv_training_iterations=1, - INIT_SEED=seed_value, - ) - - train_data = utils.random_training_set( - (100, 45), count=5, minimum=-5.0, maximum=5.0 - ) - train_data = [train_data] - - # reference is the same as for GMM projection - reference_file = pkg_resources.resource_filename( - "bob.bio.gmm.test", "data/ivector_projector.hdf5" - ) - try: - # train the projector - - ivec2.train_projector(train_data, temp_file) - - assert os.path.exists(temp_file) - - if regenerate_refs: - shutil.copy(temp_file, reference_file) - - # check projection matrix - ivec1.load_projector(reference_file) - ivec2.load_projector(temp_file) - - assert ivec1.ubm.is_similar_to(ivec2.ubm) - assert ivec1.tv.is_similar_to(ivec2.tv) - assert ivec1.whitener.is_similar_to(ivec2.whitener) - finally: - if os.path.exists(temp_file): - os.remove(temp_file) - - # generate and project random feature - feature = utils.random_array((20, 45), -5.0, 5.0, seed=84) - projected = ivec1.project(feature) - _compare( - projected, - pkg_resources.resource_filename( - "bob.bio.gmm.test", "data/ivector_projected.hdf5" - ), - ivec1.write_feature, - ivec1.read_feature, - ) - - # enroll model from random features - random_features = utils.random_training_set( - (20, 45), count=5, minimum=-5.0, maximum=5.0 - ) - enroll_features = [ivec1.project(feature) for feature in random_features] - model = ivec1.enroll(enroll_features) - _compare( - model, - pkg_resources.resource_filename("bob.bio.gmm.test", "data/ivector_model.hdf5"), - ivec1.write_model, - ivec1.read_model, - ) - - # 
compare model with probe - probe = ivec1.read_feature( - pkg_resources.resource_filename( - "bob.bio.gmm.test", "data/ivector_projected.hdf5" - ) - ) - reference_score = -0.00187151 - assert ( - abs(ivec1.score(model, probe) - reference_score) < 1e-5 - ), "The scores differ: %3.8f, %3.8f" % (ivec1.score(model, probe), reference_score) - # TODO: implement that - assert ( - abs(ivec1.score_for_multiple_probes(model, [probe, probe]) - reference_score) - < 1e-5 - ) - - -def test_ivector_plda(): - temp_file = bob.io.base.test_utils.temporary_filename() - ivec1 = bob.bio.base.load_resource( - "ivector-plda", "algorithm", preferred_package="bob.bio.gmm" - ) - ivec1.use_plda = True - - # create smaller IVector object - ivec2 = bob.bio.gmm.algorithm.IVector( - number_of_gaussians=2, - subspace_dimension_of_t=10, - kmeans_training_iterations=1, - tv_training_iterations=1, - INIT_SEED=seed_value, - use_plda=True, - plda_dim_F=2, - plda_dim_G=2, - plda_training_iterations=2, - ) - - train_data = utils.random_training_set_by_id( - (100, 45), count=5, minimum=-5.0, maximum=5.0 - ) - - # reference is the same as for GMM projection - reference_file = pkg_resources.resource_filename( - "bob.bio.gmm.test", "data/ivector2_projector.hdf5" - ) - try: - # train the projector - - ivec2.train_projector(train_data, temp_file) - - assert os.path.exists(temp_file) - - if regenerate_refs: - shutil.copy(temp_file, reference_file) - - # check projection matrix - ivec1.load_projector(reference_file) - ivec2.load_projector(temp_file) - - assert ivec1.ubm.is_similar_to(ivec2.ubm) - assert ivec1.tv.is_similar_to(ivec2.tv) - assert ivec1.whitener.is_similar_to(ivec2.whitener) - finally: - if os.path.exists(temp_file): - os.remove(temp_file) - - # generate and project random feature - feature = utils.random_array((20, 45), -5.0, 5.0, seed=84) - projected = ivec1.project(feature) - _compare( - projected, - pkg_resources.resource_filename( - "bob.bio.gmm.test", "data/ivector2_projected.hdf5" - ), - 
ivec1.write_feature, - ivec1.read_feature, - ) - - # enroll model from random features - random_features = utils.random_training_set( - (20, 45), count=5, minimum=-5.0, maximum=5.0 - ) - enroll_features = [ivec1.project(feature) for feature in random_features] - - model = ivec1.enroll(enroll_features) - _compare( - model, - pkg_resources.resource_filename("bob.bio.gmm.test", "data/ivector2_model.hdf5"), - ivec1.write_model, - ivec1.read_model, - ) - - # compare model with probe - probe = ivec1.read_feature( - pkg_resources.resource_filename( - "bob.bio.gmm.test", "data/ivector2_projected.hdf5" - ) - ) - logger.info("%f" % ivec1.score(model, probe)) - reference_score = 1.21879822 - assert ( - abs(ivec1.score(model, probe) - reference_score) < 1e-5 - ), "The scores differ: %3.8f, %3.8f" % (ivec1.score(model, probe), reference_score) - assert ( - abs(ivec1.score_for_multiple_probes(model, [probe, probe]) - reference_score) - < 1e-5 - ) - - -def test_ivector_lda_wccn_plda(): - temp_file = bob.io.base.test_utils.temporary_filename() - ivec1 = bob.bio.base.load_resource( - "ivector-lda-wccn-plda", "algorithm", preferred_package="bob.bio.gmm" - ) - ivec1.use_lda = True - ivec1.use_wccn = True - ivec1.use_plda = True - # create smaller IVector object - ivec2 = bob.bio.gmm.algorithm.IVector( - number_of_gaussians=2, - subspace_dimension_of_t=10, - kmeans_training_iterations=1, - tv_training_iterations=1, - INIT_SEED=seed_value, - use_lda=True, - lda_dim=3, - use_wccn=True, - use_plda=True, - plda_dim_F=2, - plda_dim_G=2, - plda_training_iterations=2, - ) - - train_data = utils.random_training_set_by_id( - (100, 45), count=5, minimum=-5.0, maximum=5.0 - ) - - # reference is the same as for GMM projection - reference_file = pkg_resources.resource_filename( - "bob.bio.gmm.test", "data/ivector3_projector.hdf5" - ) - try: - # train the projector - - ivec2.train_projector(train_data, temp_file) - - assert os.path.exists(temp_file) - - if regenerate_refs: - shutil.copy(temp_file, 
reference_file) - - # check projection matrix - ivec1.load_projector(reference_file) - ivec2.load_projector(temp_file) - - assert ivec1.ubm.is_similar_to(ivec2.ubm) - assert ivec1.tv.is_similar_to(ivec2.tv) - assert ivec1.whitener.is_similar_to(ivec2.whitener) - finally: - if os.path.exists(temp_file): - os.remove(temp_file) - - # generate and project random feature - feature = utils.random_array((20, 45), -5.0, 5.0, seed=84) - projected = ivec1.project(feature) - _compare( - projected, - pkg_resources.resource_filename( - "bob.bio.gmm.test", "data/ivector3_projected.hdf5" - ), - ivec1.write_feature, - ivec1.read_feature, - ) - - # enroll model from random features - random_features = utils.random_training_set( - (20, 45), count=5, minimum=-5.0, maximum=5.0 - ) - enroll_features = [ivec1.project(feature) for feature in random_features] - model = ivec1.enroll(enroll_features) - _compare( - model, - pkg_resources.resource_filename("bob.bio.gmm.test", "data/ivector3_model.hdf5"), - ivec1.write_model, - ivec1.read_model, - ) - - # compare model with probe - probe = ivec1.read_feature( - pkg_resources.resource_filename( - "bob.bio.gmm.test", "data/ivector3_projected.hdf5" - ) - ) - reference_score = 0.2954148598 - assert ( - abs(ivec1.score(model, probe) - reference_score) < 1e-5 - ), "The scores differ: %3.8f, %3.8f" % (ivec1.score(model, probe), reference_score) - assert ( - abs(ivec1.score_for_multiple_probes(model, [probe, probe]) - reference_score) - < 1e-5 - ) + # TODO YD 2021 + # probe = gmm1.read_feature( + # pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_projected.hdf5") + # ) + # reference_score = -0.01992773 + # assert ( + # abs(gmm1.score(model, probe) - reference_score) < 1e-5 + # ), "The scores differ: %3.8f, %3.8f" % (gmm1.score(model, probe), reference_score) + # assert ( + # abs(gmm1.score_for_multiple_probes(model, [probe, probe]) - reference_score) + # < 1e-5 + # ) + + +# def test_gmm_regular(): + +# temp_file = 
bob.io.base.test_utils.temporary_filename() +# gmm1 = bob.bio.base.load_resource( +# "gmm-regular", "algorithm", preferred_package="bob.bio.gmm" +# ) +# assert isinstance(gmm1, bob.bio.gmm.algorithm.GMMRegular) +# assert isinstance(gmm1, bob.bio.gmm.algorithm.GMM) +# assert isinstance( +# gmm1, bob.bio.base.pipelines.vanilla_biometrics.abstract_classes.BioAlgorithm +# ) +# assert not gmm1.performs_projection +# assert not gmm1.requires_projector_training +# assert not gmm1.use_projected_features_for_enrollment +# assert gmm1.requires_enroller_training + +# # create smaller GMM object +# gmm2 = bob.bio.gmm.algorithm.GMMRegular( +# number_of_gaussians=2, +# kmeans_training_iterations=1, +# gmm_training_iterations=1, +# INIT_SEED=seed_value, +# ) + +# train_data = utils.random_training_set( +# (100, 45), count=5, minimum=-5.0, maximum=5.0 +# ) +# reference_file = pkg_resources.resource_filename( +# "bob.bio.gmm.test", "data/gmm_projector.hdf5" +# ) +# try: +# # train the enroler +# gmm2.train_enroller([train_data], temp_file) + +# assert os.path.exists(temp_file) + +# if regenerate_refs: +# shutil.copy(temp_file, reference_file) + +# # check projection matrix +# gmm1.load_enroller(reference_file) +# gmm2.load_enroller(temp_file) + +# assert gmm1.ubm.is_similar_to(gmm2.ubm) +# finally: +# if os.path.exists(temp_file): +# os.remove(temp_file) + +# # enroll model from random features +# enroll = utils.random_training_set((20, 45), 5, -5.0, 5.0, seed=21) +# model = gmm1.enroll(enroll) +# assert isinstance(model, bob.learn.em.mixture.GMMMachine) +# _compare( +# model, +# pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_model.hdf5"), +# gmm1.write_model, +# gmm1.read_model, +# ) + +# # generate random probe feature +# probe = utils.random_array((20, 45), -5.0, 5.0, seed=84) + +# # compare model with probe +# reference_score = -0.40840148 +# assert ( +# abs(gmm1.score(model, probe) - reference_score) < 1e-5 +# ), "The scores differ: %3.8f, %3.8f" % 
(gmm1.score(model, probe), reference_score) +# # TODO: not implemented +# # assert abs(gmm1.score_for_multiple_probes(model, [probe, probe]) - reference_score) < 1e-5 + + +# def test_isv(): +# temp_file = bob.io.base.test_utils.temporary_filename() +# isv1 = bob.bio.base.load_resource( +# "isv", "algorithm", preferred_package="bob.bio.gmm" +# ) +# assert isinstance(isv1, bob.bio.gmm.algorithm.ISV) +# assert isinstance(isv1, bob.bio.gmm.algorithm.GMM) +# assert isinstance(isv1, bob.bio.base.algorithm.Algorithm) +# assert isv1.performs_projection +# assert isv1.requires_projector_training +# assert isv1.use_projected_features_for_enrollment +# assert isv1.split_training_features_by_client +# assert not isv1.requires_enroller_training + +# # create smaller GMM object +# isv2 = bob.bio.gmm.algorithm.ISV( +# number_of_gaussians=2, +# subspace_dimension_of_u=10, +# kmeans_training_iterations=1, +# gmm_training_iterations=1, +# isv_training_iterations=1, +# INIT_SEED=seed_value, +# ) + +# train_data = utils.random_training_set_by_id( +# (100, 45), count=5, minimum=-5.0, maximum=5.0 +# ) +# reference_file = pkg_resources.resource_filename( +# "bob.bio.gmm.test", "data/isv_projector.hdf5" +# ) +# try: +# # train the projector +# isv2.train_projector(train_data, temp_file) + +# assert os.path.exists(temp_file) + +# if regenerate_refs: +# shutil.copy(temp_file, reference_file) + +# # check projection matrix +# isv1.load_projector(reference_file) +# isv2.load_projector(temp_file) + +# assert isv1.ubm.is_similar_to(isv2.ubm) +# assert isv1.isvbase.is_similar_to(isv2.isvbase) +# finally: +# if os.path.exists(temp_file): +# os.remove(temp_file) + +# # generate and project random feature +# feature = utils.random_array((20, 45), -5.0, 5.0, seed=84) +# projected = isv1.project(feature) +# assert isinstance(projected, (list, tuple)) +# assert len(projected) == 2 +# assert isinstance(projected[0], bob.learn.em.GMMStats) +# assert isinstance(projected[1], numpy.ndarray) +# 
_compare_complex( +# projected, +# pkg_resources.resource_filename("bob.bio.gmm.test", "data/isv_projected.hdf5"), +# isv1.write_feature, +# isv1.read_feature, +# ) + +# # enroll model from random features +# random_features = utils.random_training_set( +# (20, 45), count=5, minimum=-5.0, maximum=5.0 +# ) +# enroll_features = [isv1.project(feature) for feature in random_features] +# model = isv1.enroll(enroll_features) +# assert isinstance(model, bob.learn.em.ISVMachine) +# _compare( +# model, +# pkg_resources.resource_filename("bob.bio.gmm.test", "data/isv_model.hdf5"), +# isv1.write_model, +# isv1.read_model, +# ) + +# # compare model with probe +# probe = isv1.read_feature( +# pkg_resources.resource_filename("bob.bio.gmm.test", "data/isv_projected.hdf5") +# ) +# reference_score = 0.02136784 +# assert ( +# abs(isv1.score(model, probe) - reference_score) < 1e-5 +# ), "The scores differ: %3.8f, %3.8f" % (isv1.score(model, probe), reference_score) +# # assert abs(isv1.score_for_multiple_probes(model, [probe]*4) - reference_score) < 1e-5, isv1.score_for_multiple_probes(model, [probe, probe]) +# # TODO: Why is the score not identical for multiple copies of the same probe? 
+# assert ( +# abs(isv1.score_for_multiple_probes(model, [probe, probe]) - reference_score) +# < 1e-4 +# ), isv1.score_for_multiple_probes(model, [probe, probe]) + + +# def test_jfa(): +# temp_file = bob.io.base.test_utils.temporary_filename() +# jfa1 = bob.bio.base.load_resource( +# "jfa", "algorithm", preferred_package="bob.bio.gmm" +# ) +# assert isinstance(jfa1, bob.bio.gmm.algorithm.JFA) +# assert isinstance(jfa1, bob.bio.gmm.algorithm.GMM) +# assert isinstance(jfa1, bob.bio.base.algorithm.Algorithm) +# assert jfa1.performs_projection +# assert jfa1.requires_projector_training +# assert jfa1.use_projected_features_for_enrollment +# assert not jfa1.split_training_features_by_client +# assert jfa1.requires_enroller_training + +# # create smaller JFA object +# jfa2 = bob.bio.gmm.algorithm.JFA( +# number_of_gaussians=2, +# subspace_dimension_of_u=2, +# subspace_dimension_of_v=2, +# kmeans_training_iterations=1, +# gmm_training_iterations=1, +# jfa_training_iterations=1, +# INIT_SEED=seed_value, +# ) + +# train_data = utils.random_training_set( +# (100, 45), count=5, minimum=-5.0, maximum=5.0 +# ) +# # reference is the same as for GMM projection +# reference_file = pkg_resources.resource_filename( +# "bob.bio.gmm.test", "data/gmm_projector.hdf5" +# ) +# try: +# # train the projector +# jfa2.train_projector(train_data, temp_file) + +# assert os.path.exists(temp_file) + +# if regenerate_refs: +# shutil.copy(temp_file, reference_file) + +# # check projection matrix +# jfa1.load_projector(reference_file) +# jfa2.load_projector(temp_file) + +# assert jfa1.ubm.is_similar_to(jfa2.ubm) +# finally: +# if os.path.exists(temp_file): +# os.remove(temp_file) + +# # generate and project random feature +# feature = utils.random_array((20, 45), -5.0, 5.0, seed=84) +# projected = jfa1.project(feature) +# assert isinstance(projected, bob.learn.em.GMMStats) +# _compare( +# projected, +# pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_projected.hdf5"), +# 
jfa1.write_feature, +# jfa1.read_feature, +# ) + +# # enroll model from random features +# random_features = utils.random_training_set_by_id( +# (20, 45), count=5, minimum=-5.0, maximum=5.0 +# ) +# train_data = [ +# [jfa1.project(feature) for feature in client_features] +# for client_features in random_features +# ] +# reference_file = pkg_resources.resource_filename( +# "bob.bio.gmm.test", "data/jfa_enroller.hdf5" +# ) +# try: +# # train the projector +# jfa2.train_enroller(train_data, temp_file) + +# assert os.path.exists(temp_file) + +# if regenerate_refs: +# shutil.copy(temp_file, reference_file) + +# # check projection matrix +# jfa1.load_enroller(reference_file) +# jfa2.load_enroller(temp_file) + +# assert jfa1.jfa_base.is_similar_to(jfa2.jfa_base) +# finally: +# if os.path.exists(temp_file): +# os.remove(temp_file) + +# # enroll model from random features +# random_features = utils.random_training_set( +# (20, 45), count=5, minimum=-5.0, maximum=5.0 +# ) +# enroll_features = [jfa1.project(feature) for feature in random_features] +# model = jfa1.enroll(enroll_features) +# assert isinstance(model, bob.learn.em.JFAMachine) +# _compare( +# model, +# pkg_resources.resource_filename("bob.bio.gmm.test", "data/jfa_model.hdf5"), +# jfa1.write_model, +# jfa1.read_model, +# ) + +# # compare model with probe +# probe = jfa1.read_feature( +# pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_projected.hdf5") +# ) +# reference_score = 0.02225812 +# assert ( +# abs(jfa1.score(model, probe) - reference_score) < 1e-5 +# ), "The scores differ: %3.8f, %3.8f" % (jfa1.score(model, probe), reference_score) +# # TODO: implement that +# # assert abs(jfa1.score_for_multiple_probes(model, [probe, probe]) - reference_score) < 1e-5, jfa1.score_for_multiple_probes(model, [probe, probe]) + + +# def test_ivector_cosine(): +# temp_file = bob.io.base.test_utils.temporary_filename() +# ivec1 = bob.bio.base.load_resource( +# "ivector-cosine", "algorithm", 
preferred_package="bob.bio.gmm" +# ) +# assert isinstance(ivec1, bob.bio.gmm.algorithm.IVector) +# assert isinstance(ivec1, bob.bio.gmm.algorithm.GMM) +# assert isinstance(ivec1, bob.bio.base.algorithm.Algorithm) +# assert ivec1.performs_projection +# assert ivec1.requires_projector_training +# assert ivec1.use_projected_features_for_enrollment +# assert ivec1.split_training_features_by_client +# assert not ivec1.requires_enroller_training + +# # create smaller IVector object +# ivec2 = bob.bio.gmm.algorithm.IVector( +# number_of_gaussians=2, +# subspace_dimension_of_t=2, +# kmeans_training_iterations=1, +# tv_training_iterations=1, +# INIT_SEED=seed_value, +# ) + +# train_data = utils.random_training_set( +# (100, 45), count=5, minimum=-5.0, maximum=5.0 +# ) +# train_data = [train_data] + +# # reference is the same as for GMM projection +# reference_file = pkg_resources.resource_filename( +# "bob.bio.gmm.test", "data/ivector_projector.hdf5" +# ) +# try: +# # train the projector + +# ivec2.train_projector(train_data, temp_file) + +# assert os.path.exists(temp_file) + +# if regenerate_refs: +# shutil.copy(temp_file, reference_file) + +# # check projection matrix +# ivec1.load_projector(reference_file) +# ivec2.load_projector(temp_file) + +# assert ivec1.ubm.is_similar_to(ivec2.ubm) +# assert ivec1.tv.is_similar_to(ivec2.tv) +# assert ivec1.whitener.is_similar_to(ivec2.whitener) +# finally: +# if os.path.exists(temp_file): +# os.remove(temp_file) + +# # generate and project random feature +# feature = utils.random_array((20, 45), -5.0, 5.0, seed=84) +# projected = ivec1.project(feature) +# _compare( +# projected, +# pkg_resources.resource_filename( +# "bob.bio.gmm.test", "data/ivector_projected.hdf5" +# ), +# ivec1.write_feature, +# ivec1.read_feature, +# ) + +# # enroll model from random features +# random_features = utils.random_training_set( +# (20, 45), count=5, minimum=-5.0, maximum=5.0 +# ) +# enroll_features = [ivec1.project(feature) for feature in 
random_features] +# model = ivec1.enroll(enroll_features) +# _compare( +# model, +# pkg_resources.resource_filename("bob.bio.gmm.test", "data/ivector_model.hdf5"), +# ivec1.write_model, +# ivec1.read_model, +# ) + +# # compare model with probe +# probe = ivec1.read_feature( +# pkg_resources.resource_filename( +# "bob.bio.gmm.test", "data/ivector_projected.hdf5" +# ) +# ) +# reference_score = -0.00187151 +# assert ( +# abs(ivec1.score(model, probe) - reference_score) < 1e-5 +# ), "The scores differ: %3.8f, %3.8f" % (ivec1.score(model, probe), reference_score) +# # TODO: implement that +# assert ( +# abs(ivec1.score_for_multiple_probes(model, [probe, probe]) - reference_score) +# < 1e-5 +# ) + + +# def test_ivector_plda(): +# temp_file = bob.io.base.test_utils.temporary_filename() +# ivec1 = bob.bio.base.load_resource( +# "ivector-plda", "algorithm", preferred_package="bob.bio.gmm" +# ) +# ivec1.use_plda = True + +# # create smaller IVector object +# ivec2 = bob.bio.gmm.algorithm.IVector( +# number_of_gaussians=2, +# subspace_dimension_of_t=10, +# kmeans_training_iterations=1, +# tv_training_iterations=1, +# INIT_SEED=seed_value, +# use_plda=True, +# plda_dim_F=2, +# plda_dim_G=2, +# plda_training_iterations=2, +# ) + +# train_data = utils.random_training_set_by_id( +# (100, 45), count=5, minimum=-5.0, maximum=5.0 +# ) + +# # reference is the same as for GMM projection +# reference_file = pkg_resources.resource_filename( +# "bob.bio.gmm.test", "data/ivector2_projector.hdf5" +# ) +# try: +# # train the projector + +# ivec2.train_projector(train_data, temp_file) + +# assert os.path.exists(temp_file) + +# if regenerate_refs: +# shutil.copy(temp_file, reference_file) + +# # check projection matrix +# ivec1.load_projector(reference_file) +# ivec2.load_projector(temp_file) + +# assert ivec1.ubm.is_similar_to(ivec2.ubm) +# assert ivec1.tv.is_similar_to(ivec2.tv) +# assert ivec1.whitener.is_similar_to(ivec2.whitener) +# finally: +# if os.path.exists(temp_file): +# 
os.remove(temp_file) + +# # generate and project random feature +# feature = utils.random_array((20, 45), -5.0, 5.0, seed=84) +# projected = ivec1.project(feature) +# _compare( +# projected, +# pkg_resources.resource_filename( +# "bob.bio.gmm.test", "data/ivector2_projected.hdf5" +# ), +# ivec1.write_feature, +# ivec1.read_feature, +# ) + +# # enroll model from random features +# random_features = utils.random_training_set( +# (20, 45), count=5, minimum=-5.0, maximum=5.0 +# ) +# enroll_features = [ivec1.project(feature) for feature in random_features] + +# model = ivec1.enroll(enroll_features) +# _compare( +# model, +# pkg_resources.resource_filename("bob.bio.gmm.test", "data/ivector2_model.hdf5"), +# ivec1.write_model, +# ivec1.read_model, +# ) + +# # compare model with probe +# probe = ivec1.read_feature( +# pkg_resources.resource_filename( +# "bob.bio.gmm.test", "data/ivector2_projected.hdf5" +# ) +# ) +# logger.info("%f" % ivec1.score(model, probe)) +# reference_score = 1.21879822 +# assert ( +# abs(ivec1.score(model, probe) - reference_score) < 1e-5 +# ), "The scores differ: %3.8f, %3.8f" % (ivec1.score(model, probe), reference_score) +# assert ( +# abs(ivec1.score_for_multiple_probes(model, [probe, probe]) - reference_score) +# < 1e-5 +# ) + + +# def test_ivector_lda_wccn_plda(): +# temp_file = bob.io.base.test_utils.temporary_filename() +# ivec1 = bob.bio.base.load_resource( +# "ivector-lda-wccn-plda", "algorithm", preferred_package="bob.bio.gmm" +# ) +# ivec1.use_lda = True +# ivec1.use_wccn = True +# ivec1.use_plda = True +# # create smaller IVector object +# ivec2 = bob.bio.gmm.algorithm.IVector( +# number_of_gaussians=2, +# subspace_dimension_of_t=10, +# kmeans_training_iterations=1, +# tv_training_iterations=1, +# INIT_SEED=seed_value, +# use_lda=True, +# lda_dim=3, +# use_wccn=True, +# use_plda=True, +# plda_dim_F=2, +# plda_dim_G=2, +# plda_training_iterations=2, +# ) + +# train_data = utils.random_training_set_by_id( +# (100, 45), count=5, 
minimum=-5.0, maximum=5.0 +# ) + +# # reference is the same as for GMM projection +# reference_file = pkg_resources.resource_filename( +# "bob.bio.gmm.test", "data/ivector3_projector.hdf5" +# ) +# try: +# # train the projector + +# ivec2.train_projector(train_data, temp_file) + +# assert os.path.exists(temp_file) + +# if regenerate_refs: +# shutil.copy(temp_file, reference_file) + +# # check projection matrix +# ivec1.load_projector(reference_file) +# ivec2.load_projector(temp_file) + +# assert ivec1.ubm.is_similar_to(ivec2.ubm) +# assert ivec1.tv.is_similar_to(ivec2.tv) +# assert ivec1.whitener.is_similar_to(ivec2.whitener) +# finally: +# if os.path.exists(temp_file): +# os.remove(temp_file) + +# # generate and project random feature +# feature = utils.random_array((20, 45), -5.0, 5.0, seed=84) +# projected = ivec1.project(feature) +# _compare( +# projected, +# pkg_resources.resource_filename( +# "bob.bio.gmm.test", "data/ivector3_projected.hdf5" +# ), +# ivec1.write_feature, +# ivec1.read_feature, +# ) + +# # enroll model from random features +# random_features = utils.random_training_set( +# (20, 45), count=5, minimum=-5.0, maximum=5.0 +# ) +# enroll_features = [ivec1.project(feature) for feature in random_features] +# model = ivec1.enroll(enroll_features) +# _compare( +# model, +# pkg_resources.resource_filename("bob.bio.gmm.test", "data/ivector3_model.hdf5"), +# ivec1.write_model, +# ivec1.read_model, +# ) + +# # compare model with probe +# probe = ivec1.read_feature( +# pkg_resources.resource_filename( +# "bob.bio.gmm.test", "data/ivector3_projected.hdf5" +# ) +# ) +# reference_score = 0.2954148598 +# assert ( +# abs(ivec1.score(model, probe) - reference_score) < 1e-5 +# ), "The scores differ: %3.8f, %3.8f" % (ivec1.score(model, probe), reference_score) +# assert ( +# abs(ivec1.score_for_multiple_probes(model, [probe, probe]) - reference_score) +# < 1e-5 +# ) diff --git a/setup.py b/setup.py index a69ee95..3b512f4 100644 --- a/setup.py +++ b/setup.py @@ 
-110,10 +110,6 @@ setup( "ivector-plda = bob.bio.gmm.config.algorithm.ivector_plda:algorithm", "ivector-lda-wccn-plda = bob.bio.gmm.config.algorithm.ivector_lda_wccn_plda:algorithm", ], - "bob.bio.bioalgorithm": [ - "gmm = bob.bio.gmm.config.bioalgorithm.gmm:bioalgorithm", - "gmm-regular = bob.bio.gmm.config.bioalgorithm.gmm_regular:bioalgorithm", - ], }, # Classifiers are important if you plan to distribute this package through # PyPI. You can find the complete list of classifiers that are valid and -- GitLab