Skip to content
Snippets Groups Projects
Commit 8397399d authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

Add C++ Roll/Unroll support

parent 7a07e9bf
Branches
Tags
No related merge requests found
...@@ -21,7 +21,7 @@ include_dirs = [ ...@@ -21,7 +21,7 @@ include_dirs = [
xbob.core.get_include(), xbob.core.get_include(),
] ]
packages = ['bob-io >= 1.2.2', 'bob-machine >= 1.2.2'] packages = ['bob-io >= 2.0.0a2']
version = '2.0.0a0' version = '2.0.0a0'
setup( setup(
...@@ -66,6 +66,7 @@ setup( ...@@ -66,6 +66,7 @@ setup(
"xbob/learn/mlp/rprop.cpp", "xbob/learn/mlp/rprop.cpp",
"xbob/learn/mlp/backprop.cpp", "xbob/learn/mlp/backprop.cpp",
"xbob/learn/mlp/trainer.cpp", "xbob/learn/mlp/trainer.cpp",
"xbob/learn/mlp/cxx/roll.cpp",
"xbob/learn/mlp/cxx/machine.cpp", "xbob/learn/mlp/cxx/machine.cpp",
"xbob/learn/mlp/cxx/cross_entropy.cpp", "xbob/learn/mlp/cxx/cross_entropy.cpp",
"xbob/learn/mlp/cxx/square_error.cpp", "xbob/learn/mlp/cxx/square_error.cpp",
......
/**
* @author Laurent El Shafey <Laurent.El-Shafey@idiap.ch>
* @date Tue Jun 25 18:52:26 CEST 2013
*
* Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland
*/
#include <bob/core/assert.h>
#include <xbob.learn.mlp/roll.h>
int bob::learn::mlp::detail::getNbParameters(const bob::learn::mlp::Machine& machine)
{
const std::vector<blitz::Array<double,1> >& b = machine.getBiases();
const std::vector<blitz::Array<double,2> >& w = machine.getWeights();
return bob::learn::mlp::detail::getNbParameters(w, b);
}
int bob::learn::mlp::detail::getNbParameters(
const std::vector<blitz::Array<double,2> >& w,
const std::vector<blitz::Array<double,1> >& b)
{
bob::core::array::assertSameDimensionLength(w.size(), b.size());
int N = 0;
for (int i=0; i<(int)w.size(); ++i)
N += b[i].numElements() + w[i].numElements();
return N;
}
void bob::learn::mlp::unroll(const bob::learn::mlp::Machine& machine,
blitz::Array<double,1>& vec)
{
const std::vector<blitz::Array<double,1> >& b = machine.getBiases();
const std::vector<blitz::Array<double,2> >& w = machine.getWeights();
unroll(w, b, vec);
}
void bob::learn::mlp::unroll(const std::vector<blitz::Array<double,2> >& w,
const std::vector<blitz::Array<double,1> >& b, blitz::Array<double,1>& vec)
{
// 1/ Check number of elements
const int N = bob::learn::mlp::detail::getNbParameters(w, b);
bob::core::array::assertSameDimensionLength(vec.extent(0), N);
// 2/ Roll
blitz::Range rall = blitz::Range::all();
int offset=0;
for (int i=0; i<(int)w.size(); ++i)
{
const int Nb = b[i].extent(0);
blitz::Range rb(offset,offset+Nb-1);
vec(rb) = b[i];
offset += Nb;
const int Nw0 = w[i].extent(0);
const int Nw1 = w[i].extent(1);
blitz::TinyVector<int,1> tv(Nw1);
for (int j=0; j<Nw0; ++j)
{
blitz::Range rw(offset,offset+Nw1-1);
vec(rw) = w[i](j,rall);
offset += Nw1;
}
}
}
void bob::learn::mlp::roll(bob::learn::mlp::Machine& machine,
const blitz::Array<double,1>& vec)
{
std::vector<blitz::Array<double,1> >& b = machine.updateBiases();
std::vector<blitz::Array<double,2> >& w = machine.updateWeights();
roll(w, b, vec);
}
void bob::learn::mlp::roll(std::vector<blitz::Array<double,2> >& w,
std::vector<blitz::Array<double,1> >& b, const blitz::Array<double,1>& vec)
{
// 1/ Check number of elements
const int N = bob::learn::mlp::detail::getNbParameters(w, b);
bob::core::array::assertSameDimensionLength(vec.extent(0), N);
// 2/ Roll
blitz::Range rall = blitz::Range::all();
int offset=0;
for (int i=0; i<(int)w.size(); ++i)
{
const int Nb = b[i].extent(0);
blitz::Array<double,1> vb = vec(blitz::Range(offset,offset+Nb-1));
b[i] = vb;
offset += Nb;
const int Nw0 = w[i].extent(0);
const int Nw1 = w[i].extent(1);
blitz::TinyVector<int,1> tv(Nw1);
for (int j=0; j<Nw0; ++j)
{
blitz::Array<double,1> vw = vec(blitz::Range(offset,offset+Nw1-1));
w[i](j,rall) = vw;
offset += Nw1;
}
}
}
/**
* @author Laurent El Shafey <Laurent.El-Shafey@idiap.ch>
* @date Tue Jun 25 18:48:20 CEST 2013
*
* Copyright (C) 2011-2014 Idiap Research Institute, Martigny, Switzerland
*/
#ifndef BOB_LEARN_MLP_ROLL_H
#define BOB_LEARN_MLP_ROLL_H
#include <vector>
#include <blitz/array.h>
#include "machine.h"
namespace bob { namespace learn { namespace mlp {
namespace detail {
/**
* @brief Returns the number of parameters (weights and biases) in an
* MLP.
*/
int getNbParameters(const bob::learn::mlp::Machine& machine);
/**
* @brief Returns the number of parameters (weights and biases).
*/
int getNbParameters(const std::vector<blitz::Array<double,2> >& weights,
const std::vector<blitz::Array<double,1> >& biases);
}
/**
* @brief Puts the parameters (weights and biases) of the machine in a
* large single 1D vector
*/
void unroll(const bob::learn::mlp::Machine& machine, blitz::Array<double,1>& vec);
/**
* @brief Puts the parameters (weights and biases) in a large single 1D vector
*/
void unroll(const std::vector<blitz::Array<double,2> >& weights,
const std::vector<blitz::Array<double,1> >& biases,
blitz::Array<double,1>& vec);
/**
* @brief Sets the parameters (weights and biases) of the machine from a
* large single 1D vector
*/
void roll(bob::learn::mlp::Machine& machine, const blitz::Array<double,1>& vec);
/**
* @brief Sets the parameters (weights and biases) from a
* large single 1D vector
*/
void roll(std::vector<blitz::Array<double,2> >& weights,
std::vector<blitz::Array<double,1> >& biases,
const blitz::Array<double,1>& vec);
}}}
#endif /* BOB_LEARN_MLP_ROLL_H */
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment