Skip to content
Snippets Groups Projects
Commit 103130f5 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

use numpy testing to get more accurate error messages

parent f1b8babc
No related branches found
No related tags found
1 merge request!20Update test_gsvd.py
Pipeline #54752 failed
...@@ -9,9 +9,9 @@ Tests GSVD ...@@ -9,9 +9,9 @@ Tests GSVD
Basically these tests test the GSVD relation. Basically these tests test the GSVD relation.
Given 2 matrices A and B GSVD(A,B) = [U,V,X,C,S] where, Given 2 matrices A and B GSVD(A,B) = [U,V,X,C,S] where,
A= (X * C.T * U^T)^T and A= (X * C.T * U^T)^T and
B= (X * S.T * V^T)^T and B= (X * S.T * V^T)^T and
C**2 + S**2 = 1 C**2 + S**2 = 1
""" """
...@@ -34,22 +34,22 @@ def gsvd_relations(A,B): ...@@ -34,22 +34,22 @@ def gsvd_relations(A,B):
A_check = numpy.dot(numpy.dot(X,C.T),U.T).T A_check = numpy.dot(numpy.dot(X,C.T),U.T).T
nose.tools.eq_( (abs(A-A_check) < 1e-10).all(), True ) nose.tools.eq_( (abs(A-A_check) < 1e-10).all(), True )
# Cheking the relation B= (X * S.T * V^T)^T # Cheking the relation B= (X * S.T * V^T)^T
B_check = numpy.dot(numpy.dot(X,S.T),V.T).T B_check = numpy.dot(numpy.dot(X,S.T),V.T).T
nose.tools.eq_( (abs(B-B_check) < 1e-10).all(), True ) nose.tools.eq_( (abs(B-B_check) < 1e-10).all(), True )
def svd_relations(A): def svd_relations(A):
[U, S, V] = bob.math.svd(A) [U, S, V] = bob.math.svd(A)
A_check = numpy.dot(numpy.dot(V,S), U) A_check = numpy.dot(numpy.dot(V,S), U)
nose.tools.eq_( (abs(A-A_check) < 1e-10).all(), True ) nose.tools.eq_( (abs(A-A_check) < 1e-10).all(), True )
def test_first_case(): def test_first_case():
#Testing the first scenario of gsvd: #Testing the first scenario of gsvd:
#M-K-L >= 0 (check http://www.netlib.org/lapack/explore-html/d1/d7e/group__double_g_esing_ga4a187519e5c71da3b3f67c85e9baf0f2.html#ga4a187519e5c71da3b3f67c85e9baf0f2) #M-K-L >= 0 (check http://www.netlib.org/lapack/explore-html/d1/d7e/group__double_g_esing_ga4a187519e5c71da3b3f67c85e9baf0f2.html#ga4a187519e5c71da3b3f67c85e9baf0f2)
A = numpy.random.rand(10,10) A = numpy.random.rand(10,10)
B = numpy.random.rand(790,10) B = numpy.random.rand(790,10)
...@@ -58,9 +58,9 @@ def test_first_case(): ...@@ -58,9 +58,9 @@ def test_first_case():
def test_second_case(): def test_second_case():
#Testing the second scenario of gsvd: #Testing the second scenario of gsvd:
#M-K-L < 0 (check http://www.netlib.org/lapack/explore-html/d1/d7e/group__double_g_esing_ga4a187519e5c71da3b3f67c85e9baf0f2.html#ga4a187519e5c71da3b3f67c85e9baf0f2) #M-K-L < 0 (check http://www.netlib.org/lapack/explore-html/d1/d7e/group__double_g_esing_ga4a187519e5c71da3b3f67c85e9baf0f2.html#ga4a187519e5c71da3b3f67c85e9baf0f2)
A = numpy.random.rand(4,5) A = numpy.random.rand(4,5)
...@@ -98,23 +98,23 @@ def test_svd_relation(): ...@@ -98,23 +98,23 @@ def test_svd_relation():
def test_svd_signal(): def test_svd_signal():
##Testing SVD signal ##Testing SVD signal
##This test was imported from bob.learn.linear ##This test was imported from bob.learn.linear
A = numpy.array([[3,-3,100], [4,-4,50], [3.5,-3.5,-50], [3.8,-3.7,-100]], dtype='float64') A = numpy.array([[3,-3,100], [4,-4,50], [3.5,-3.5,-50], [3.8,-3.7,-100]], dtype='float64')
U_ref = numpy.array([[ 2.20825004e-03, -1.80819459e-03, -9.99995927e-01], U_ref = numpy.array([[ 2.20825004e-03, -1.80819459e-03, -9.99995927e-01],
[ -7.09549949e-01, 7.04649416e-01, -2.84101853e-03], [ -7.09549949e-01, 7.04649416e-01, -2.84101853e-03],
[ 7.04651683e-01, 7.09553332e-01, 2.73037723e-04]]) [ 7.04651683e-01, 7.09553332e-01, 2.73037723e-04]])
[U,S,V] = bob.math.svd(A) [U,S,V] = bob.math.svd(A)
nose.tools.eq_((abs(U-U_ref) < 1e-7).all(), True) numpy.testing.assert_allclose(U, U_ref, rtol=1e-5, atol=1e-6)
svd_relations(A) svd_relations(A)
def test_svd_signal_book_example(): def test_svd_signal_book_example():
## Reference copied from here http://prod.sandia.gov/techlib/access-control.cgi/2007/076422.pdf ## Reference copied from here http://prod.sandia.gov/techlib/access-control.cgi/2007/076422.pdf
A = numpy.array([[2.5, 63.5, 40.1, 78, 61.1], A = numpy.array([[2.5, 63.5, 40.1, 78, 61.1],
[0.9, 58.0, 25.1, 78, 94.1], [0.9, 58.0, 25.1, 78, 94.1],
...@@ -130,6 +130,6 @@ def test_svd_signal_book_example(): ...@@ -130,6 +130,6 @@ def test_svd_signal_book_example():
[U,S,V] = bob.math.svd(A) [U,S,V] = bob.math.svd(A)
assert U[0,0] > 0 assert U[0,0] > 0
svd_relations(A) svd_relations(A)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment