Skip to content
Snippets Groups Projects
Commit a0ec1e4f authored by Yannick DAYER's avatar Yannick DAYER
Browse files

Clean GMM test, add write_model and multiple_probe

parent 95c678d9
Branches
No related tags found
1 merge request!26Python implementation of GMM
......@@ -268,7 +268,11 @@ class GMM(BioAlgorithm, BaseEstimator):
# Feature comparison #
def read_model(self, model_file):
"""Reads the model, which is a GMM machine"""
return GMMMachine.from_hdf5(bob.io.base.HDF5File(model_file))
return GMMMachine.from_hdf5(bob.io.base.HDF5File(model_file), ubm=self.ubm)
def write_model(self, model, model_file):
"""Write the features (GMM_Stats)"""
return model.save(model_file)
def score(self, biometric_reference: GMMMachine, data: GMMStats):
"""Computes the score for the given model and the given probe.
......@@ -283,13 +287,13 @@ class GMM(BioAlgorithm, BaseEstimator):
The probe data to compare to the model.
"""
assert isinstance(biometric_reference, GMMMachine) # TODO is it a list?
assert isinstance(biometric_reference, GMMMachine)
assert isinstance(data, GMMStats)
return self.scoring_function(
models_means=[biometric_reference],
ubm=self.ubm,
test_stats=data,
frame_length_normalisation=True,
frame_length_normalization=True,
)[0, 0]
def score_multiple_biometric_references(
......@@ -307,26 +311,27 @@ class GMM(BioAlgorithm, BaseEstimator):
The probe data to compare to the models.
"""
assert isinstance(biometric_references, GMMMachine) # TODO is it a list?
assert isinstance(biometric_references, GMMMachine)
assert isinstance(data, GMMStats)
return self.scoring_function(
models_means=biometric_references,
ubm=self.ubm,
test_stats=data,
frame_length_normalisation=True,
frame_length_normalization=True,
)
# def score_for_multiple_probes(self, model, probes):
# """This function computes the score between the given model and several given probe files."""
# assert isinstance(model, GMMMachine)
# for probe in probes:
# assert isinstance(probe, GMMStats)
# # logger.warn("Please verify that this function is correct")
# return self.probe_fusion_function(
# self.scoring_function(
# model.means, self.ubm, probes, [], frame_length_normalisation=True
# )
# )
def score_for_multiple_probes(self, model, probes):
"""This function computes the score between the given model and several given probe files."""
assert isinstance(model, GMMMachine)
for probe in probes:
assert isinstance(probe, GMMStats)
# logger.warn("Please verify that this function is correct")
return self.scoring_function(
models_means=model.means,
ubm=self.ubm,
test_stats=probes,
frame_length_normalization=True,
).mean()
def fit(self, X, y=None, **kwargs):
"""Trains the UBM."""
......
No preview for this file type
No preview for this file type
No preview for this file type
......@@ -24,7 +24,6 @@ import sys
import numpy
import pkg_resources
import pytest
import bob.bio.gmm
import bob.io.base
......@@ -35,7 +34,7 @@ from bob.bio.base.test import utils
logger = logging.getLogger(__name__)
regenerate_refs = True
regenerate_refs = False
seed_value = 5489
......@@ -73,11 +72,8 @@ def _compare_complex(
assert numpy.allclose(d, r, atol=1e-5)
@pytest.mark.isolated_gmm
def test_gmm():
temp_file = (
"./temptest/test_file" # TODO bob.io.base.test_utils.temporary_filename()
)
temp_file = bob.io.base.test_utils.temporary_filename()
gmm1 = bob.bio.base.load_resource(
"gmm", "bioalgorithm", preferred_package="bob.bio.gmm"
)
......@@ -85,11 +81,9 @@ def test_gmm():
assert isinstance(
gmm1, bob.bio.base.pipelines.vanilla_biometrics.abstract_classes.BioAlgorithm
)
# assert gmm1.performs_projection
# assert gmm1.requires_projector_training
# assert not gmm1.use_projected_features_for_enrollment
# assert not gmm1.split_training_features_by_client
# assert not gmm1.requires_enroller_training
# Fix the number of gaussians for tests
gmm1.number_of_gaussians = 2
# create smaller GMM object
gmm2 = bob.bio.gmm.bioalgorithm.GMM(
......@@ -149,7 +143,7 @@ def test_gmm():
probe = gmm1.read_feature(
pkg_resources.resource_filename("bob.bio.gmm.test", "data/gmm_projected.hdf5")
)
reference_score = -0.01676570
reference_score = -0.01992773
assert (
abs(gmm1.score(model, probe) - reference_score) < 1e-5
), "The scores differ: %3.8f, %3.8f" % (gmm1.score(model, probe), reference_score)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment