Skip to content
Snippets Groups Projects
Commit acd09320 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[gmm] small fixes to make sure the gmm algorithm runs

parent c557079c
No related branches found
No related tags found
1 merge request!31[gmm] small fixes to make sure the gmm algorithm runs
Pipeline #59188 failed
......@@ -16,7 +16,6 @@ import logging
from typing import Callable
from typing import Union
import dask
import dask.array as da
import numpy as np
......@@ -71,6 +70,7 @@ class GMM(BioAlgorithm, BaseEstimator):
scoring_function: Callable = linear_scoring,
# RNG
init_seed: int = 5489,
**kwargs,
):
"""Initializes the local UBM-GMM tool chain.
......@@ -144,7 +144,7 @@ class GMM(BioAlgorithm, BaseEstimator):
self.ubm = None
super().__init__()
super().__init__(**kwargs)
def _check_feature(self, feature):
"""Checks that the features are appropriate"""
......@@ -196,12 +196,13 @@ class GMM(BioAlgorithm, BaseEstimator):
Returns a GMMMachine tuned from the UBM with MAP on a biometric reference data.
"""
[self._check_feature(feature) for feature in data]
array = da.vstack(data)
for feature in data:
self._check_feature(feature)
data = np.vstack(data)
# Use the array to train a GMM and return it
logger.info("Enrolling with %d feature vectors", array.shape[0])
logger.info("Enrolling with %d feature vectors", data.shape[0])
with dask.config.set(scheduler="threads"):
gmm = GMMMachine(
n_gaussians=self.number_of_gaussians,
trainer="map",
......@@ -216,7 +217,7 @@ class GMM(BioAlgorithm, BaseEstimator):
map_relevance_factor=self.enroll_relevance_factor,
map_alpha=self.enroll_alpha,
)
gmm.fit(array)
gmm.fit(data)
return gmm
def read_biometric_reference(self, model_file):
......@@ -277,10 +278,11 @@ class GMM(BioAlgorithm, BaseEstimator):
frame_length_normalization=True,
)
def fit(self, X, y=None, **kwargs):
def fit(self, array, y=None, **kwargs):
"""Trains the UBM."""
# Stack all the samples in a 2D array of features
array = da.vstack(X).persist()
if isinstance(array, da.Array):
array = array.persist()
logger.debug("UBM with %d feature vectors", array.shape[0])
......@@ -309,7 +311,7 @@ class GMM(BioAlgorithm, BaseEstimator):
# Train the GMM
logger.info("Training UBM GMM")
self.ubm.fit(array, ubm_train=True)
self.ubm.fit(array)
return self
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment