Skip to content
Snippets Groups Projects
Commit b5ff4115 authored by Yannick DAYER's avatar Yannick DAYER
Browse files

Fix the scoring functions and enroll parameters

parent ccbf6317
Branches
No related tags found
1 merge request!29Fix the scoring functions and adapt parameters to bob.learn.em
Pipeline #58870 failed
......@@ -23,7 +23,7 @@ import numpy as np
from h5py import File as HDF5File
from sklearn.base import BaseEstimator
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgorithm
from bob.bio.base.pipelines.vanilla_biometrics import BioAlgorithm
from bob.learn.em import GMMMachine
from bob.learn.em import GMMStats
from bob.learn.em import KMeansMachine
......@@ -213,8 +213,8 @@ class GMM(BioAlgorithm, BaseEstimator):
update_variances=self.enroll_update_variances,
update_weights=self.enroll_update_weights,
mean_var_update_threshold=self.variance_threshold,
relevance_factor=self.enroll_relevance_factor,
alpha=self.enroll_alpha,
map_relevance_factor=self.enroll_relevance_factor,
map_alpha=self.enroll_alpha,
)
gmm.fit(array)
return gmm
......@@ -244,7 +244,6 @@ class GMM(BioAlgorithm, BaseEstimator):
The probe data to compare to the model.
"""
logger.debug(f"scoring {biometric_reference}, {probe}")
if not isinstance(probe, GMMStats):
# Projection is done here instead of in transform (or it would be applied to enrollment data too...)
probe = self.project(probe)
......@@ -253,7 +252,7 @@ class GMM(BioAlgorithm, BaseEstimator):
ubm=self.ubm,
test_stats=probe,
frame_length_normalization=True,
)[0, 0]
)[0]
def score_multiple_biometric_references(
self, biometric_references: "list[GMMMachine]", probe: GMMStats
......@@ -270,32 +269,13 @@ class GMM(BioAlgorithm, BaseEstimator):
The probe data to compare to the models.
"""
logger.debug(f"scoring {biometric_references}, {probe}")
assert isinstance(biometric_references[0], GMMMachine), type(
biometric_references[0]
)
stats = self.project(probe) if not isinstance(probe, GMMStats) else probe
return self.scoring_function(
models_means=biometric_references,
ubm=self.ubm,
test_stats=stats,
frame_length_normalization=True,
).reshape((-1,))
def score_for_multiple_probes(self, biometric_reference, probes):
"""This function computes the score between the given model and several given probe files."""
logger.debug(f"scoring {biometric_reference}, {probes}")
assert isinstance(biometric_reference, GMMMachine)
stats = [
self.project(probe) if not isinstance(probe, GMMStats) else probe
for probe in probes
]
return self.scoring_function(
models_means=biometric_reference.means,
ubm=self.ubm,
test_stats=stats,
frame_length_normalization=True,
).reshape((-1,))
)
def fit(self, X, y=None, **kwargs):
"""Trains the UBM."""
......
......@@ -160,16 +160,10 @@ def test_score():
gmm1.score(biometric_reference, probe), reference_score, decimal=5
)
multi_probes = gmm1.score_for_multiple_probes(
biometric_reference, [probe, probe, probe]
)
assert multi_probes.shape == (3,), multi_probes.shape
numpy.testing.assert_almost_equal(multi_probes, reference_score, decimal=5)
multi_refs = gmm1.score_multiple_biometric_references(
[biometric_reference, biometric_reference, biometric_reference], probe
)
assert multi_refs.shape == (3,), multi_refs.shape
assert multi_refs.shape == (3, 1), multi_refs.shape
numpy.testing.assert_almost_equal(multi_refs, reference_score, decimal=5)
# With not projected data
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment