Skip to content
Snippets Groups Projects

Fix GMM after the port from bob.bio.gmm

Merged Yannick DAYER requested to merge fix-gmm into master
2 unresolved threads
Files
3
@@ -159,7 +159,7 @@ class GMM(GMMMachine, BioAlgorithm):
gmm = GMMMachine(
n_gaussians=self.n_gaussians,
trainer="map",
ubm=copy.deepcopy(self.ubm),
ubm=copy.deepcopy(self),
convergence_threshold=self.convergence_threshold,
max_fitting_steps=self.enroll_iterations,
random_state=self.random_state,
@@ -175,11 +175,7 @@ class GMM(GMMMachine, BioAlgorithm):
def read_biometric_reference(self, model_file):
"""Reads an enrolled reference model, which is a MAP GMMMachine."""
if self.ubm is None:
raise ValueError(
"You must load a UBM before reading a biometric reference."
)
return GMMMachine.from_hdf5(HDF5File(model_file, "r"), ubm=self.ubm)
return GMMMachine.from_hdf5(HDF5File(model_file, "r"), ubm=self)
def write_biometric_reference(self, model: GMMMachine, model_file):
"""Write the enrolled reference (MAP GMMMachine) into a file."""
@@ -201,7 +197,7 @@ class GMM(GMMMachine, BioAlgorithm):
probe = self.project(probe)
return self.scoring_function(
models_means=[biometric_reference],
ubm=self.ubm,
ubm=self,
test_stats=probe,
frame_length_normalization=True,
)[0]
@@ -224,7 +220,7 @@ class GMM(GMMMachine, BioAlgorithm):
stats = self.project(probe)
return self.scoring_function(
models_means=biometric_references,
ubm=self.ubm,
ubm=self,
test_stats=stats,
frame_length_normalization=True,
)
@@ -259,7 +255,7 @@ class GMM(GMMMachine, BioAlgorithm):
data.save(path)
def custom_enrolled_load_fn(self, path):
return GMMMachine.from_hdf5(path, ubm=self.ubm)
return GMMMachine.from_hdf5(path, ubm=self)
def _more_tags(self):
return {
Loading