From 70aa8f72df5c85352a8d081f15c73a9c63aae10c Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Tue, 26 Oct 2021 17:07:50 +0200
Subject: [PATCH] fix plda tests

---
 bob/bio/base/test/test_algorithms.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/bob/bio/base/test/test_algorithms.py b/bob/bio/base/test/test_algorithms.py
index 45df58e2..4b31dd11 100644
--- a/bob/bio/base/test/test_algorithms.py
+++ b/bob/bio/base/test/test_algorithms.py
@@ -180,7 +180,7 @@ def test_lda():
   # enroll model from random features
   enroll = utils.random_training_set(5, 5, 0., 255., seed=21)
   model = lda1.enroll(enroll)
-  _compare(model, pkg_resources.resource_filename('bob.bio.base.test', 'data/lda_model.hdf5'), lda1.write_model, lda1.read_model)  
+  _compare(model, pkg_resources.resource_filename('bob.bio.base.test', 'data/lda_model.hdf5'), lda1.write_model, lda1.read_model)
   # compare model with probe
   probe = lda1.read_feature(pkg_resources.resource_filename('bob.bio.base.test', 'data/lda_projected.hdf5'))
   reference_score = -233.30450012
@@ -335,8 +335,13 @@ def test_plda():
     plda1.load_enroller(reference_file)
     plda3.load_enroller(temp_file)
 
-    assert plda1.pca_machine.is_similar_to(plda3.pca_machine)
-    assert plda1.plda_base.is_similar_to(plda3.plda_base)
+    numpy.testing.assert_array_almost_equal(plda1.pca_machine.weights, plda3.pca_machine.weights)
+    numpy.testing.assert_array_almost_equal(plda1.pca_machine.biases, plda3.pca_machine.biases)
+    numpy.testing.assert_array_almost_equal(abs(plda1.plda_base.f), abs(plda3.plda_base.f))
+    numpy.testing.assert_array_almost_equal(plda1.plda_base.g, plda3.plda_base.g)
+    numpy.testing.assert_array_almost_equal(plda1.plda_base.mu, plda3.plda_base.mu)
+    numpy.testing.assert_array_almost_equal(plda1.plda_base.sigma, plda3.plda_base.sigma)
+    numpy.testing.assert_array_almost_equal(plda1.plda_base.variance_threshold, plda3.plda_base.variance_threshold)
 
   finally:
     if os.path.exists(temp_file): os.remove(temp_file)
-- 
GitLab