Skip to content
Snippets Groups Projects

Fix the scoring functions and adapt parameters to bob.learn.em

2 files
+ 6
32
Compare changes
  • Side-by-side
  • Inline

Files

+ 5
25
@@ -23,7 +23,7 @@ import numpy as np
from h5py import File as HDF5File
from sklearn.base import BaseEstimator
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgorithm
from bob.bio.base.pipelines.vanilla_biometrics import BioAlgorithm
from bob.learn.em import GMMMachine
from bob.learn.em import GMMStats
from bob.learn.em import KMeansMachine
@@ -213,8 +213,8 @@ class GMM(BioAlgorithm, BaseEstimator):
update_variances=self.enroll_update_variances,
update_weights=self.enroll_update_weights,
mean_var_update_threshold=self.variance_threshold,
relevance_factor=self.enroll_relevance_factor,
alpha=self.enroll_alpha,
map_relevance_factor=self.enroll_relevance_factor,
map_alpha=self.enroll_alpha,
)
gmm.fit(array)
return gmm
@@ -244,7 +244,6 @@ class GMM(BioAlgorithm, BaseEstimator):
The probe data to compare to the model.
"""
logger.debug(f"scoring {biometric_reference}, {probe}")
if not isinstance(probe, GMMStats):
# Projection is done here instead of in transform (or it would be applied to enrollment data too...)
probe = self.project(probe)
@@ -253,7 +252,7 @@ class GMM(BioAlgorithm, BaseEstimator):
ubm=self.ubm,
test_stats=probe,
frame_length_normalization=True,
)[0, 0]
)[0]
def score_multiple_biometric_references(
self, biometric_references: "list[GMMMachine]", probe: GMMStats
@@ -270,32 +269,13 @@ class GMM(BioAlgorithm, BaseEstimator):
The probe data to compare to the models.
"""
logger.debug(f"scoring {biometric_references}, {probe}")
assert isinstance(biometric_references[0], GMMMachine), type(
biometric_references[0]
)
stats = self.project(probe) if not isinstance(probe, GMMStats) else probe
return self.scoring_function(
models_means=biometric_references,
ubm=self.ubm,
test_stats=stats,
frame_length_normalization=True,
).reshape((-1,))
def score_for_multiple_probes(self, biometric_reference, probes):
"""This function computes the score between the given model and several given probe files."""
logger.debug(f"scoring {biometric_reference}, {probes}")
assert isinstance(biometric_reference, GMMMachine)
stats = [
self.project(probe) if not isinstance(probe, GMMStats) else probe
for probe in probes
]
return self.scoring_function(
models_means=biometric_reference.means,
ubm=self.ubm,
test_stats=stats,
frame_length_normalization=True,
).reshape((-1,))
)
def fit(self, X, y=None, **kwargs):
"""Trains the UBM."""
Loading