Skip to content
Snippets Groups Projects
Commit fb3d75e8 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Merge branch 'fix-gsvd' into 'master'

Fixed a bug when P=<N. Added a test case to fix it.

See merge request !8
parents 60253d78 d32b8bbd
Branches
Tags
1 merge request!8Fixed a bug when P=<N. Added a test case to fix it.
Pipeline #
......@@ -225,9 +225,10 @@ void bob::math::gsvd( blitz::Array<double,2>& A,
}
// B - diag(C) part. Here the C is LxL
// Swaping
bob::math::swap_(C_1d, iwork.get(), K, std::min(M,r));
blitz::Array<double,2> C_diag (L,L); C_diag = 0;
bob::math::diag(C_1d(blitz::Range(0,L-1)), C_diag);
bob::math::diag(C_1d(blitz::Range(K,K+L-1)), C_diag);
C(blitz::Range(K, M-1), blitz::Range(K, K+L-1)) = C_diag;
//2.2 Preparing S
......@@ -239,7 +240,7 @@ void bob::math::gsvd( blitz::Array<double,2>& A,
// Swap
bob::math::swap_(S_1d, iwork.get(), K, std::min(M,r));
blitz::Array<double,2> S_diag (L,L); S_diag = 0;
bob::math::diag(S_1d(blitz::Range(0,L-1)), S_diag);
bob::math::diag(S_1d(blitz::Range(K,K+L-1)), S_diag);
S(blitz::Range(0, L-1), blitz::Range(K, K+L-1)) = S_diag;
}
......
......@@ -65,7 +65,7 @@ void gsvd(blitz::Array<double,2>& A,
void swap_(blitz::Array<T,1>& A, int* indexes, int begin, int end) {
T aux = 0;
int fortran_index = 0;
for (int i=0; i<A.extent(0); i++){
for (int i=begin; i<A.extent(0); i++){
fortran_index = indexes[i]-1;
aux = A(i);
A(i) = A(fortran_index);
......@@ -88,9 +88,12 @@ void gsvd(blitz::Array<double,2>& A,
int fortran_index = 0;
for (int i=begin; i<end; i++){
fortran_index = indexes[i]-1;
aux = A(blitz::Range::all(), i);
A(blitz::Range::all(), i) = A(blitz::Range::all(), fortran_index);
A(blitz::Range::all(), fortran_index) = aux;
if (fortran_index < A.extent(1)){
aux = A(blitz::Range::all(), i);
A(blitz::Range::all(), i) = A(blitz::Range::all(), fortran_index);
A(blitz::Range::all(), fortran_index) = aux;
}
}
}
......
......@@ -38,7 +38,6 @@ def gsvd_relations(A,B):
B_check = numpy.dot(numpy.dot(X,S.T),V.T).T
nose.tools.eq_( (abs(B-B_check) < 1e-10).all(), True )
del U,V,X,C,S
def test_first_case():
......@@ -65,3 +64,15 @@ def test_second_case():
gsvd_relations(A, B)
def test_corner_case():
"""
Testing when P <= N.
"""
A = numpy.random.rand(25, 25)
B = numpy.random.rand(25, 25)
gsvd_relations(A, B)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment