Skip to content
Snippets Groups Projects
Commit 2ebdd357 authored by Yannick DAYER's avatar Yannick DAYER
Browse files

Multi-scores tests, references regen on latest ver

parent a451243a
No related branches found
No related tags found
1 merge request!26Python implementation of GMM
Pipeline #58364 failed
...@@ -230,9 +230,8 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -230,9 +230,8 @@ class GMM(BioAlgorithm, BaseEstimator):
logger.debug(f"scoring {biometric_reference}, {probe}") logger.debug(f"scoring {biometric_reference}, {probe}")
if not isinstance(probe, GMMStats): if not isinstance(probe, GMMStats):
probe = self.project( # Projection is done here instead of in transform (or it would be applied to enrollment data too...)
probe probe = self.project(probe)
) # Projection is done here instead of transform (or it would be applied to enrollment data too...)
return self.scoring_function( return self.scoring_function(
models_means=[biometric_reference], models_means=[biometric_reference],
ubm=self.ubm, ubm=self.ubm,
...@@ -265,26 +264,22 @@ class GMM(BioAlgorithm, BaseEstimator): ...@@ -265,26 +264,22 @@ class GMM(BioAlgorithm, BaseEstimator):
ubm=self.ubm, ubm=self.ubm,
test_stats=stats, test_stats=stats,
frame_length_normalization=True, frame_length_normalization=True,
) ).reshape((-1,))
def score_for_multiple_probes(self, model, probes): def score_for_multiple_probes(self, biometric_reference, probes):
"""This function computes the score between the given model and several given probe files.""" """This function computes the score between the given model and several given probe files."""
logger.debug(f"scoring {model}, {probes}") logger.debug(f"scoring {biometric_reference}, {probes}")
assert isinstance(model, GMMMachine) assert isinstance(biometric_reference, GMMMachine)
stats = [ stats = [
self.project(probe) if not isinstance(probe, GMMStats) else probe self.project(probe) if not isinstance(probe, GMMStats) else probe
for probe in probes for probe in probes
] ]
return ( return self.scoring_function(
self.scoring_function( models_means=biometric_reference.means,
models_means=model.means,
ubm=self.ubm, ubm=self.ubm,
test_stats=stats, test_stats=stats,
frame_length_normalization=True, frame_length_normalization=True,
) ).reshape((-1,))
.mean()
.reshape((-1,))
)
def fit(self, X, y=None, **kwargs): def fit(self, X, y=None, **kwargs):
"""Trains the UBM.""" """Trains the UBM."""
......
No preview for this file type
No preview for this file type
No preview for this file type
...@@ -18,8 +18,10 @@ ...@@ -18,8 +18,10 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging import logging
import os
import tempfile import tempfile
import numpy
import pkg_resources import pkg_resources
import bob.bio.gmm import bob.bio.gmm
...@@ -46,6 +48,8 @@ def test_class(): ...@@ -46,6 +48,8 @@ def test_class():
gmm1, bob.bio.base.pipelines.vanilla_biometrics.abstract_classes.BioAlgorithm gmm1, bob.bio.base.pipelines.vanilla_biometrics.abstract_classes.BioAlgorithm
) )
assert gmm1.number_of_gaussians == 512 assert gmm1.number_of_gaussians == 512
assert "bob_fit_supports_dask_array" in gmm1._get_tags()
assert gmm1.transform(None) is None
def test_training(): def test_training():
...@@ -125,11 +129,17 @@ def test_enroll(): ...@@ -125,11 +129,17 @@ def test_enroll():
"bob.bio.gmm.test", "data/gmm_enrolled.hdf5" "bob.bio.gmm.test", "data/gmm_enrolled.hdf5"
) )
if regenerate_refs: if regenerate_refs:
biometric_reference.save(reference_file) gmm1.write_biometric_reference(biometric_reference, reference_file)
# Compare to pre-generated file
gmm2 = gmm1.read_biometric_reference(reference_file) gmm2 = gmm1.read_biometric_reference(reference_file)
assert biometric_reference.is_similar_to(gmm2) assert biometric_reference.is_similar_to(gmm2)
with tempfile.NamedTemporaryFile(prefix="bob_", suffix="_bioref.hdf5") as fd:
temp_file = fd.name
gmm1.write_biometric_reference(biometric_reference, reference_file)
assert os.path.exists(temp_file)
def test_score(): def test_score():
gmm1 = GMM(number_of_gaussians=2) gmm1 = GMM(number_of_gaussians=2)
...@@ -143,18 +153,27 @@ def test_score(): ...@@ -143,18 +153,27 @@ def test_score():
probe = GMMStats.from_hdf5( probe = GMMStats.from_hdf5(
pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_projected.hdf5") pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_projected.hdf5")
) )
probe_data = utils.random_array((20, 45), -5.0, 5.0, seed=84)
reference_score = 0.045073 reference_score = -0.098980
assert (
abs(gmm1.score(biometric_reference, probe) - reference_score) < 1e-5 numpy.testing.assert_almost_equal(
), "The scores differ: %3.8f, %3.8f" % ( gmm1.score(biometric_reference, probe), reference_score, decimal=5
gmm1.score(biometric_reference, probe),
reference_score,
) )
assert (
abs( multi_probes = gmm1.score_for_multiple_probes(
gmm1.score_for_multiple_probes(biometric_reference, [probe, probe]) biometric_reference, [probe, probe, probe]
- reference_score
) )
< 1e-5 assert multi_probes.shape == (3,), multi_probes.shape
numpy.testing.assert_almost_equal(multi_probes, reference_score, decimal=5)
multi_refs = gmm1.score_multiple_biometric_references(
[biometric_reference, biometric_reference, biometric_reference], probe
)
assert multi_refs.shape == (3,), multi_refs.shape
numpy.testing.assert_almost_equal(multi_refs, reference_score, decimal=5)
# With not projected data
numpy.testing.assert_almost_equal(
gmm1.score(biometric_reference, probe_data), reference_score, decimal=5
) )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment