Skip to content
Snippets Groups Projects
Commit acd09320 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[gmm] small fixes to make sure the gmm algorithm runs

parent c557079c
No related branches found
No related tags found
1 merge request!31[gmm] small fixes to make sure the gmm algorithm runs
Pipeline #59188 failed
...@@ -16,7 +16,6 @@ import logging ...@@ -16,7 +16,6 @@ import logging
from typing import Callable from typing import Callable
from typing import Union from typing import Union
import dask
import dask.array as da import dask.array as da
import numpy as np import numpy as np
...@@ -71,6 +70,7 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -71,6 +70,7 @@ class GMM(BioAlgorithm, BaseEstimator):
scoring_function: Callable = linear_scoring, scoring_function: Callable = linear_scoring,
# RNG # RNG
init_seed: int = 5489, init_seed: int = 5489,
**kwargs,
): ):
"""Initializes the local UBM-GMM tool chain. """Initializes the local UBM-GMM tool chain.
...@@ -144,7 +144,7 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -144,7 +144,7 @@ class GMM(BioAlgorithm, BaseEstimator):
self.ubm = None self.ubm = None
super().__init__() super().__init__(**kwargs)
def _check_feature(self, feature): def _check_feature(self, feature):
"""Checks that the features are appropriate""" """Checks that the features are appropriate"""
...@@ -196,27 +196,28 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -196,27 +196,28 @@ class GMM(BioAlgorithm, BaseEstimator):
Returns a GMMMachine tuned from the UBM with MAP on a biometric reference data. Returns a GMMMachine tuned from the UBM with MAP on a biometric reference data.
""" """
[self._check_feature(feature) for feature in data] for feature in data:
array = da.vstack(data) self._check_feature(feature)
data = np.vstack(data)
# Use the array to train a GMM and return it # Use the array to train a GMM and return it
logger.info("Enrolling with %d feature vectors", array.shape[0]) logger.info("Enrolling with %d feature vectors", data.shape[0])
with dask.config.set(scheduler="threads"): gmm = GMMMachine(
gmm = GMMMachine( n_gaussians=self.number_of_gaussians,
n_gaussians=self.number_of_gaussians, trainer="map",
trainer="map", ubm=copy.deepcopy(self.ubm),
ubm=copy.deepcopy(self.ubm), convergence_threshold=self.training_threshold,
convergence_threshold=self.training_threshold, max_fitting_steps=self.gmm_enroll_iterations,
max_fitting_steps=self.gmm_enroll_iterations, random_state=self.rng,
random_state=self.rng, update_means=self.enroll_update_means,
update_means=self.enroll_update_means, update_variances=self.enroll_update_variances,
update_variances=self.enroll_update_variances, update_weights=self.enroll_update_weights,
update_weights=self.enroll_update_weights, mean_var_update_threshold=self.variance_threshold,
mean_var_update_threshold=self.variance_threshold, map_relevance_factor=self.enroll_relevance_factor,
map_relevance_factor=self.enroll_relevance_factor, map_alpha=self.enroll_alpha,
map_alpha=self.enroll_alpha, )
) gmm.fit(data)
gmm.fit(array)
return gmm return gmm
def read_biometric_reference(self, model_file): def read_biometric_reference(self, model_file):
...@@ -277,10 +278,11 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -277,10 +278,11 @@ class GMM(BioAlgorithm, BaseEstimator):
frame_length_normalization=True, frame_length_normalization=True,
) )
def fit(self, X, y=None, **kwargs): def fit(self, array, y=None, **kwargs):
"""Trains the UBM.""" """Trains the UBM."""
# Stack all the samples in a 2D array of features # Stack all the samples in a 2D array of features
array = da.vstack(X).persist() if isinstance(array, da.Array):
array = array.persist()
logger.debug("UBM with %d feature vectors", array.shape[0]) logger.debug("UBM with %d feature vectors", array.shape[0])
...@@ -309,7 +311,7 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -309,7 +311,7 @@ class GMM(BioAlgorithm, BaseEstimator):
# Train the GMM # Train the GMM
logger.info("Training UBM GMM") logger.info("Training UBM GMM")
self.ubm.fit(array, ubm_train=True) self.ubm.fit(array)
return self return self
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment