Skip to content
Snippets Groups Projects
Commit b5ff4115 authored by Yannick DAYER's avatar Yannick DAYER
Browse files

Fix the scoring functions and enroll parameters

parent ccbf6317
No related branches found
No related tags found
1 merge request!29Fix the scoring functions and adapt parameters to bob.learn.em
Pipeline #58870 failed
...@@ -23,7 +23,7 @@ import numpy as np ...@@ -23,7 +23,7 @@ import numpy as np
from h5py import File as HDF5File from h5py import File as HDF5File
from sklearn.base import BaseEstimator from sklearn.base import BaseEstimator
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgorithm from bob.bio.base.pipelines.vanilla_biometrics import BioAlgorithm
from bob.learn.em import GMMMachine from bob.learn.em import GMMMachine
from bob.learn.em import GMMStats from bob.learn.em import GMMStats
from bob.learn.em import KMeansMachine from bob.learn.em import KMeansMachine
...@@ -213,8 +213,8 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -213,8 +213,8 @@ class GMM(BioAlgorithm, BaseEstimator):
update_variances=self.enroll_update_variances, update_variances=self.enroll_update_variances,
update_weights=self.enroll_update_weights, update_weights=self.enroll_update_weights,
mean_var_update_threshold=self.variance_threshold, mean_var_update_threshold=self.variance_threshold,
relevance_factor=self.enroll_relevance_factor, map_relevance_factor=self.enroll_relevance_factor,
alpha=self.enroll_alpha, map_alpha=self.enroll_alpha,
) )
gmm.fit(array) gmm.fit(array)
return gmm return gmm
...@@ -244,7 +244,6 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -244,7 +244,6 @@ class GMM(BioAlgorithm, BaseEstimator):
The probe data to compare to the model. The probe data to compare to the model.
""" """
logger.debug(f"scoring {biometric_reference}, {probe}")
if not isinstance(probe, GMMStats): if not isinstance(probe, GMMStats):
# Projection is done here instead of in transform (or it would be applied to enrollment data too...) # Projection is done here instead of in transform (or it would be applied to enrollment data too...)
probe = self.project(probe) probe = self.project(probe)
...@@ -253,7 +252,7 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -253,7 +252,7 @@ class GMM(BioAlgorithm, BaseEstimator):
ubm=self.ubm, ubm=self.ubm,
test_stats=probe, test_stats=probe,
frame_length_normalization=True, frame_length_normalization=True,
)[0, 0] )[0]
def score_multiple_biometric_references( def score_multiple_biometric_references(
self, biometric_references: "list[GMMMachine]", probe: GMMStats self, biometric_references: "list[GMMMachine]", probe: GMMStats
...@@ -270,32 +269,13 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -270,32 +269,13 @@ class GMM(BioAlgorithm, BaseEstimator):
The probe data to compare to the models. The probe data to compare to the models.
""" """
logger.debug(f"scoring {biometric_references}, {probe}")
assert isinstance(biometric_references[0], GMMMachine), type(
biometric_references[0]
)
stats = self.project(probe) if not isinstance(probe, GMMStats) else probe stats = self.project(probe) if not isinstance(probe, GMMStats) else probe
return self.scoring_function( return self.scoring_function(
models_means=biometric_references, models_means=biometric_references,
ubm=self.ubm, ubm=self.ubm,
test_stats=stats, test_stats=stats,
frame_length_normalization=True, frame_length_normalization=True,
).reshape((-1,)) )
def score_for_multiple_probes(self, biometric_reference, probes):
"""This function computes the score between the given model and several given probe files."""
logger.debug(f"scoring {biometric_reference}, {probes}")
assert isinstance(biometric_reference, GMMMachine)
stats = [
self.project(probe) if not isinstance(probe, GMMStats) else probe
for probe in probes
]
return self.scoring_function(
models_means=biometric_reference.means,
ubm=self.ubm,
test_stats=stats,
frame_length_normalization=True,
).reshape((-1,))
def fit(self, X, y=None, **kwargs): def fit(self, X, y=None, **kwargs):
"""Trains the UBM.""" """Trains the UBM."""
......
...@@ -160,16 +160,10 @@ def test_score(): ...@@ -160,16 +160,10 @@ def test_score():
gmm1.score(biometric_reference, probe), reference_score, decimal=5 gmm1.score(biometric_reference, probe), reference_score, decimal=5
) )
multi_probes = gmm1.score_for_multiple_probes(
biometric_reference, [probe, probe, probe]
)
assert multi_probes.shape == (3,), multi_probes.shape
numpy.testing.assert_almost_equal(multi_probes, reference_score, decimal=5)
multi_refs = gmm1.score_multiple_biometric_references( multi_refs = gmm1.score_multiple_biometric_references(
[biometric_reference, biometric_reference, biometric_reference], probe [biometric_reference, biometric_reference, biometric_reference], probe
) )
assert multi_refs.shape == (3,), multi_refs.shape assert multi_refs.shape == (3, 1), multi_refs.shape
numpy.testing.assert_almost_equal(multi_refs, reference_score, decimal=5) numpy.testing.assert_almost_equal(multi_refs, reference_score, decimal=5)
# With not projected data # With not projected data
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment