Skip to content
Snippets Groups Projects
Commit 152684d6 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch 'gmm-params' into 'master'

Fix the scoring functions and adapt parameters to bob.learn.em

See merge request !29
parents ccbf6317 b5ff4115
No related branches found
No related tags found
1 merge request!29Fix the scoring functions and adapt parameters to bob.learn.em
Pipeline #58887 failed
...@@ -23,7 +23,7 @@ import numpy as np ...@@ -23,7 +23,7 @@ import numpy as np
from h5py import File as HDF5File from h5py import File as HDF5File
from sklearn.base import BaseEstimator from sklearn.base import BaseEstimator
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import BioAlgorithm from bob.bio.base.pipelines.vanilla_biometrics import BioAlgorithm
from bob.learn.em import GMMMachine from bob.learn.em import GMMMachine
from bob.learn.em import GMMStats from bob.learn.em import GMMStats
from bob.learn.em import KMeansMachine from bob.learn.em import KMeansMachine
...@@ -213,8 +213,8 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -213,8 +213,8 @@ class GMM(BioAlgorithm, BaseEstimator):
update_variances=self.enroll_update_variances, update_variances=self.enroll_update_variances,
update_weights=self.enroll_update_weights, update_weights=self.enroll_update_weights,
mean_var_update_threshold=self.variance_threshold, mean_var_update_threshold=self.variance_threshold,
relevance_factor=self.enroll_relevance_factor, map_relevance_factor=self.enroll_relevance_factor,
alpha=self.enroll_alpha, map_alpha=self.enroll_alpha,
) )
gmm.fit(array) gmm.fit(array)
return gmm return gmm
...@@ -244,7 +244,6 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -244,7 +244,6 @@ class GMM(BioAlgorithm, BaseEstimator):
The probe data to compare to the model. The probe data to compare to the model.
""" """
logger.debug(f"scoring {biometric_reference}, {probe}")
if not isinstance(probe, GMMStats): if not isinstance(probe, GMMStats):
# Projection is done here instead of in transform (or it would be applied to enrollment data too...) # Projection is done here instead of in transform (or it would be applied to enrollment data too...)
probe = self.project(probe) probe = self.project(probe)
...@@ -253,7 +252,7 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -253,7 +252,7 @@ class GMM(BioAlgorithm, BaseEstimator):
ubm=self.ubm, ubm=self.ubm,
test_stats=probe, test_stats=probe,
frame_length_normalization=True, frame_length_normalization=True,
)[0, 0] )[0]
def score_multiple_biometric_references( def score_multiple_biometric_references(
self, biometric_references: "list[GMMMachine]", probe: GMMStats self, biometric_references: "list[GMMMachine]", probe: GMMStats
...@@ -270,32 +269,13 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -270,32 +269,13 @@ class GMM(BioAlgorithm, BaseEstimator):
The probe data to compare to the models. The probe data to compare to the models.
""" """
logger.debug(f"scoring {biometric_references}, {probe}")
assert isinstance(biometric_references[0], GMMMachine), type(
biometric_references[0]
)
stats = self.project(probe) if not isinstance(probe, GMMStats) else probe stats = self.project(probe) if not isinstance(probe, GMMStats) else probe
return self.scoring_function( return self.scoring_function(
models_means=biometric_references, models_means=biometric_references,
ubm=self.ubm, ubm=self.ubm,
test_stats=stats, test_stats=stats,
frame_length_normalization=True, frame_length_normalization=True,
).reshape((-1,)) )
def score_for_multiple_probes(self, biometric_reference, probes):
"""This function computes the score between the given model and several given probe files."""
logger.debug(f"scoring {biometric_reference}, {probes}")
assert isinstance(biometric_reference, GMMMachine)
stats = [
self.project(probe) if not isinstance(probe, GMMStats) else probe
for probe in probes
]
return self.scoring_function(
models_means=biometric_reference.means,
ubm=self.ubm,
test_stats=stats,
frame_length_normalization=True,
).reshape((-1,))
def fit(self, X, y=None, **kwargs): def fit(self, X, y=None, **kwargs):
"""Trains the UBM.""" """Trains the UBM."""
......
...@@ -160,16 +160,10 @@ def test_score(): ...@@ -160,16 +160,10 @@ def test_score():
gmm1.score(biometric_reference, probe), reference_score, decimal=5 gmm1.score(biometric_reference, probe), reference_score, decimal=5
) )
multi_probes = gmm1.score_for_multiple_probes(
biometric_reference, [probe, probe, probe]
)
assert multi_probes.shape == (3,), multi_probes.shape
numpy.testing.assert_almost_equal(multi_probes, reference_score, decimal=5)
multi_refs = gmm1.score_multiple_biometric_references( multi_refs = gmm1.score_multiple_biometric_references(
[biometric_reference, biometric_reference, biometric_reference], probe [biometric_reference, biometric_reference, biometric_reference], probe
) )
assert multi_refs.shape == (3,), multi_refs.shape assert multi_refs.shape == (3, 1), multi_refs.shape
numpy.testing.assert_almost_equal(multi_refs, reference_score, decimal=5) numpy.testing.assert_almost_equal(multi_refs, reference_score, decimal=5)
# With not projected data # With not projected data
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment