Skip to content
Snippets Groups Projects
Commit eddf8eb5 authored by Manuel Günther's avatar Manuel Günther
Browse files

Now use checked functions from bob::math

parent ccc91ef4
No related branches found
No related tags found
1 merge request!7Now use checked functions from bob::math
Pipeline #
...@@ -127,7 +127,7 @@ namespace bob { namespace learn { namespace linear { ...@@ -127,7 +127,7 @@ namespace bob { namespace learn { namespace linear {
blitz::Array<double,1> preMean(n_features); blitz::Array<double,1> preMean(n_features);
blitz::Array<double,2> Sw(n_features, n_features); blitz::Array<double,2> Sw(n_features, n_features);
blitz::Array<double,2> Sb(n_features, n_features); blitz::Array<double,2> Sb(n_features, n_features);
bob::math::scatters_(data, Sw, Sb, preMean); bob::math::scatters(data, Sw, Sb, preMean);
// computes the generalized eigenvalue decomposition // computes the generalized eigenvalue decomposition
// so to find the eigen vectors/values of Sw^(-1) * Sb // so to find the eigen vectors/values of Sw^(-1) * Sb
...@@ -137,11 +137,11 @@ namespace bob { namespace learn { namespace linear { ...@@ -137,11 +137,11 @@ namespace bob { namespace learn { namespace linear {
if (m_use_pinv) { if (m_use_pinv) {
//note: misuse V and Sw as temporary place holders for data //note: misuse V and Sw as temporary place holders for data
bob::math::pinv_(Sw, V); //V now contains Sw^-1 bob::math::pinv(Sw, V); //V now contains Sw^-1
bob::math::prod_(V, Sb, Sw); //Sw now contains Sw^-1*Sb bob::math::prod(V, Sb, Sw); //Sw now contains Sw^-1*Sb
blitz::Array<std::complex<double>,1> Dtemp(eigen_values_.shape()); blitz::Array<std::complex<double>,1> Dtemp(eigen_values_.shape());
blitz::Array<std::complex<double>,2> Vtemp(V.shape()); blitz::Array<std::complex<double>,2> Vtemp(V.shape());
bob::math::eig_(Sw, Vtemp, Dtemp); //V now contains eigen-vectors bob::math::eig(Sw, Vtemp, Dtemp); //V now contains eigen-vectors
//sorting: we know this problem on has real eigen-values //sorting: we know this problem on has real eigen-values
blitz::Range a = blitz::Range::all(); blitz::Range a = blitz::Range::all();
...@@ -153,7 +153,7 @@ namespace bob { namespace learn { namespace linear { ...@@ -153,7 +153,7 @@ namespace bob { namespace learn { namespace linear {
} }
} }
else { else {
bob::math::eigSym_(Sb, Sw, V, eigen_values_); bob::math::eigSym(Sb, Sw, V, eigen_values_);
} }
// Convert ascending order to descending order // Convert ascending order to descending order
......
...@@ -155,7 +155,7 @@ namespace bob { namespace learn { namespace linear { ...@@ -155,7 +155,7 @@ namespace bob { namespace learn { namespace linear {
void Machine::forward_ (const blitz::Array<double,1>& input, blitz::Array<double,1>& output) const { void Machine::forward_ (const blitz::Array<double,1>& input, blitz::Array<double,1>& output) const {
m_buffer = (input - m_input_sub) / m_input_div; m_buffer = (input - m_input_sub) / m_input_div;
bob::math::prod_(m_buffer, m_weight, output); bob::math::prod(m_buffer, m_weight, output);
for (int i=0; i<m_weight.extent(1); ++i) for (int i=0; i<m_weight.extent(1); ++i)
output(i) = m_activation->f(output(i) + m_bias(i)); output(i) = m_activation->f(output(i) + m_bias(i));
......
...@@ -68,12 +68,12 @@ namespace bob { namespace learn { namespace linear { ...@@ -68,12 +68,12 @@ namespace bob { namespace learn { namespace linear {
*/ */
blitz::Array<double,1> mean(X.extent(1)); blitz::Array<double,1> mean(X.extent(1));
blitz::Array<double,2> Sigma(X.extent(1), X.extent(1)); blitz::Array<double,2> Sigma(X.extent(1), X.extent(1));
bob::math::scatter_(X, Sigma, mean); bob::math::scatter(X, Sigma, mean);
Sigma /= (X.extent(0)-1); //unbiased variance estimator Sigma /= (X.extent(0)-1); //unbiased variance estimator
blitz::Array<double,2> U(X.extent(1), X.extent(1)); blitz::Array<double,2> U(X.extent(1), X.extent(1));
blitz::Array<double,1> e(X.extent(1)); blitz::Array<double,1> e(X.extent(1));
bob::math::eigSym_(Sigma, U, e); bob::math::eigSym(Sigma, U, e);
e.reverseSelf(0); e.reverseSelf(0);
U.reverseSelf(1); U.reverseSelf(1);
...@@ -123,7 +123,7 @@ namespace bob { namespace learn { namespace linear { ...@@ -123,7 +123,7 @@ namespace bob { namespace learn { namespace linear {
const int rank_1 = (rank == (int)X.extent(1))? X.extent(1) : X.extent(0); const int rank_1 = (rank == (int)X.extent(1))? X.extent(1) : X.extent(0);
blitz::Array<double,2> U(X.extent(1), rank_1); blitz::Array<double,2> U(X.extent(1), rank_1);
blitz::Array<double,1> sigma(rank_1); blitz::Array<double,1> sigma(rank_1);
bob::math::svd_(data, U, sigma, safe_svd); bob::math::svd(data, U, sigma, safe_svd);
/** /**
* sets the linear machine with the results: * sets the linear machine with the results:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment