Skip to content
Snippets Groups Projects

[gmm] small fixes to make sure the gmm algorithm runs

Merged Amir MOHAMMADI requested to merge fixes into master
All threads resolved!
+26 -24
@@ -16,7 +16,6 @@ import logging
@@ -16,7 +16,6 @@ import logging
from typing import Callable
from typing import Callable
from typing import Union
from typing import Union
import dask
import dask.array as da
import dask.array as da
import numpy as np
import numpy as np
@@ -71,6 +70,7 @@ class GMM(BioAlgorithm, BaseEstimator):
@@ -71,6 +70,7 @@ class GMM(BioAlgorithm, BaseEstimator):
scoring_function: Callable = linear_scoring,
scoring_function: Callable = linear_scoring,
# RNG
# RNG
init_seed: int = 5489,
init_seed: int = 5489,
 
**kwargs,
):
):
"""Initializes the local UBM-GMM tool chain.
"""Initializes the local UBM-GMM tool chain.
@@ -144,7 +144,7 @@ class GMM(BioAlgorithm, BaseEstimator):
@@ -144,7 +144,7 @@ class GMM(BioAlgorithm, BaseEstimator):
self.ubm = None
self.ubm = None
super().__init__()
super().__init__(**kwargs)
def _check_feature(self, feature):
def _check_feature(self, feature):
"""Checks that the features are appropriate"""
"""Checks that the features are appropriate"""
@@ -196,27 +196,28 @@ class GMM(BioAlgorithm, BaseEstimator):
@@ -196,27 +196,28 @@ class GMM(BioAlgorithm, BaseEstimator):
Returns a GMMMachine tuned from the UBM with MAP on a biometric reference data.
Returns a GMMMachine tuned from the UBM with MAP on a biometric reference data.
"""
"""
[self._check_feature(feature) for feature in data]
for feature in data:
array = da.vstack(data)
self._check_feature(feature)
 
 
data = np.vstack(data)
# Use the array to train a GMM and return it
# Use the array to train a GMM and return it
logger.info("Enrolling with %d feature vectors", array.shape[0])
logger.info("Enrolling with %d feature vectors", data.shape[0])
with dask.config.set(scheduler="threads"):
gmm = GMMMachine(
gmm = GMMMachine(
n_gaussians=self.number_of_gaussians,
n_gaussians=self.number_of_gaussians,
trainer="map",
trainer="map",
ubm=copy.deepcopy(self.ubm),
ubm=copy.deepcopy(self.ubm),
convergence_threshold=self.training_threshold,
convergence_threshold=self.training_threshold,
max_fitting_steps=self.gmm_enroll_iterations,
max_fitting_steps=self.gmm_enroll_iterations,
random_state=self.rng,
random_state=self.rng,
update_means=self.enroll_update_means,
update_means=self.enroll_update_means,
update_variances=self.enroll_update_variances,
update_variances=self.enroll_update_variances,
update_weights=self.enroll_update_weights,
update_weights=self.enroll_update_weights,
mean_var_update_threshold=self.variance_threshold,
mean_var_update_threshold=self.variance_threshold,
map_relevance_factor=self.enroll_relevance_factor,
map_relevance_factor=self.enroll_relevance_factor,
map_alpha=self.enroll_alpha,
map_alpha=self.enroll_alpha,
)
)
gmm.fit(data)
gmm.fit(array)
return gmm
return gmm
def read_biometric_reference(self, model_file):
def read_biometric_reference(self, model_file):
@@ -277,10 +278,11 @@ class GMM(BioAlgorithm, BaseEstimator):
@@ -277,10 +278,11 @@ class GMM(BioAlgorithm, BaseEstimator):
frame_length_normalization=True,
frame_length_normalization=True,
)
)
def fit(self, X, y=None, **kwargs):
def fit(self, array, y=None, **kwargs):
"""Trains the UBM."""
"""Trains the UBM."""
# Stack all the samples in a 2D array of features
# Stack all the samples in a 2D array of features
array = da.vstack(X).persist()
if isinstance(array, da.Array):
 
array = array.persist()
logger.debug("UBM with %d feature vectors", array.shape[0])
logger.debug("UBM with %d feature vectors", array.shape[0])
@@ -309,7 +311,7 @@ class GMM(BioAlgorithm, BaseEstimator):
@@ -309,7 +311,7 @@ class GMM(BioAlgorithm, BaseEstimator):
# Train the GMM
# Train the GMM
logger.info("Training UBM GMM")
logger.info("Training UBM GMM")
self.ubm.fit(array, ubm_train=True)
self.ubm.fit(array)
return self
return self
Loading