Skip to content
Snippets Groups Projects
Commit 70aa8f72 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

fix plda tests

parent 8d70e55c
Branches
Tags
No related merge requests found
Pipeline #55597 passed
...@@ -180,7 +180,7 @@ def test_lda(): ...@@ -180,7 +180,7 @@ def test_lda():
# enroll model from random features # enroll model from random features
enroll = utils.random_training_set(5, 5, 0., 255., seed=21) enroll = utils.random_training_set(5, 5, 0., 255., seed=21)
model = lda1.enroll(enroll) model = lda1.enroll(enroll)
_compare(model, pkg_resources.resource_filename('bob.bio.base.test', 'data/lda_model.hdf5'), lda1.write_model, lda1.read_model) _compare(model, pkg_resources.resource_filename('bob.bio.base.test', 'data/lda_model.hdf5'), lda1.write_model, lda1.read_model)
# compare model with probe # compare model with probe
probe = lda1.read_feature(pkg_resources.resource_filename('bob.bio.base.test', 'data/lda_projected.hdf5')) probe = lda1.read_feature(pkg_resources.resource_filename('bob.bio.base.test', 'data/lda_projected.hdf5'))
reference_score = -233.30450012 reference_score = -233.30450012
...@@ -335,8 +335,13 @@ def test_plda(): ...@@ -335,8 +335,13 @@ def test_plda():
plda1.load_enroller(reference_file) plda1.load_enroller(reference_file)
plda3.load_enroller(temp_file) plda3.load_enroller(temp_file)
assert plda1.pca_machine.is_similar_to(plda3.pca_machine) numpy.testing.assert_array_almost_equal(plda1.pca_machine.weights, plda3.pca_machine.weights)
assert plda1.plda_base.is_similar_to(plda3.plda_base) numpy.testing.assert_array_almost_equal(plda1.pca_machine.biases, plda3.pca_machine.biases)
numpy.testing.assert_array_almost_equal(abs(plda1.plda_base.f), abs(plda3.plda_base.f))
numpy.testing.assert_array_almost_equal(plda1.plda_base.g, plda3.plda_base.g)
numpy.testing.assert_array_almost_equal(plda1.plda_base.mu, plda3.plda_base.mu)
numpy.testing.assert_array_almost_equal(plda1.plda_base.sigma, plda3.plda_base.sigma)
numpy.testing.assert_array_almost_equal(plda1.plda_base.variance_threshold, plda3.plda_base.variance_threshold)
finally: finally:
if os.path.exists(temp_file): os.remove(temp_file) if os.path.exists(temp_file): os.remove(temp_file)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment