Skip to content
Snippets Groups Projects
Commit d7c5f69c authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Removed the inheritance with the EMTrainer and removed some private variables...

Removed the inheritance with the EMTrainer and removed some private variables for the new binding. Still in progress, but it is an initial version.
parent f2bafa71
No related branches found
No related tags found
No related merge requests found
...@@ -12,43 +12,73 @@ ...@@ -12,43 +12,73 @@
#include <boost/random.hpp> #include <boost/random.hpp>
/*
bob::learn::misc::KMeansTrainer::KMeansTrainer(double convergence_threshold, bob::learn::misc::KMeansTrainer::KMeansTrainer(double convergence_threshold,
size_t max_iterations, bool compute_likelihood, InitializationMethod i_m): size_t max_iterations, bool compute_likelihood, InitializationMethod i_m)
bob::learn::misc::EMTrainer<bob::learn::misc::KMeansMachine, blitz::Array<double,2> >(
convergence_threshold, max_iterations, compute_likelihood),
m_initialization_method(i_m),
m_rng(new boost::mt19937()), m_average_min_distance(0),
m_zeroethOrderStats(0), m_firstOrderStats(0,0)
{ {
m_initialization_method = i_m;
m_zeroethOrderStats = 0;
m_firstOrderStats = 0;
m_average_min_distance = 0;
m_compute_likelihood = compute_likelihood;
m_convergence_threshold = convergence_threshold;
m_max_iterations = max_iterations;
//m_rng(new boost::mt19937());
} }
*/
bob::learn::misc::KMeansTrainer::KMeansTrainer(const bob::learn::misc::KMeansTrainer& other): bob::learn::misc::KMeansTrainer::KMeansTrainer(InitializationMethod i_m)
bob::learn::misc::EMTrainer<bob::learn::misc::KMeansMachine, blitz::Array<double,2> >(
other.m_convergence_threshold, other.m_max_iterations, other.m_compute_likelihood),
m_initialization_method(other.m_initialization_method),
m_rng(other.m_rng), m_average_min_distance(other.m_average_min_distance),
m_zeroethOrderStats(bob::core::array::ccopy(other.m_zeroethOrderStats)),
m_firstOrderStats(bob::core::array::ccopy(other.m_firstOrderStats))
{ {
m_initialization_method = i_m;
m_zeroethOrderStats = 0;
m_firstOrderStats = 0;
m_average_min_distance = 0;
//m_rng(new boost::mt19937());
}
bob::learn::misc::KMeansTrainer::KMeansTrainer(const bob::learn::misc::KMeansTrainer& other){
//m_convergence_threshold = other.m_convergence_threshold;
//m_max_iterations = other.m_max_iterations;
//m_compute_likelihood = other.m_compute_likelihood;
m_initialization_method = other.m_initialization_method;
m_rng = other.m_rng;
m_average_min_distance = other.m_average_min_distance;
m_zeroethOrderStats = bob::core::array::ccopy(other.m_zeroethOrderStats);
m_firstOrderStats = bob::core::array::ccopy(other.m_firstOrderStats);
} }
bob::learn::misc::KMeansTrainer& bob::learn::misc::KMeansTrainer::operator= bob::learn::misc::KMeansTrainer& bob::learn::misc::KMeansTrainer::operator=
(const bob::learn::misc::KMeansTrainer& other) (const bob::learn::misc::KMeansTrainer& other)
{ {
if(this != &other) if(this != &other)
{ {
EMTrainer<bob::learn::misc::KMeansMachine, blitz::Array<double,2> >::operator=(other); //m_compute_likelihood = other.m_compute_likelihood;
m_initialization_method = other.m_initialization_method; //m_convergence_threshold = other.m_convergence_threshold;
m_rng = other.m_rng; //m_max_iterations = other.m_max_iterations;
m_average_min_distance = other.m_average_min_distance; m_rng = other.m_rng;
m_zeroethOrderStats.reference(bob::core::array::ccopy(other.m_zeroethOrderStats)); m_initialization_method = other.m_initialization_method;
m_firstOrderStats.reference(bob::core::array::ccopy(other.m_firstOrderStats)); m_average_min_distance = other.m_average_min_distance;
m_zeroethOrderStats = bob::core::array::ccopy(other.m_zeroethOrderStats);
m_firstOrderStats = bob::core::array::ccopy(other.m_firstOrderStats);
} }
return *this; return *this;
} }
bool bob::learn::misc::KMeansTrainer::operator==(const bob::learn::misc::KMeansTrainer& b) const { bool bob::learn::misc::KMeansTrainer::operator==(const bob::learn::misc::KMeansTrainer& b) const {
return EMTrainer<bob::learn::misc::KMeansMachine, blitz::Array<double,2> >::operator==(b) && return //m_compute_likelihood == b.m_compute_likelihood &&
//m_convergence_threshold == b.m_convergence_threshold &&
//m_max_iterations == b.m_max_iterations &&
m_initialization_method == b.m_initialization_method && m_initialization_method == b.m_initialization_method &&
*m_rng == *(b.m_rng) && m_average_min_distance == b.m_average_min_distance && *m_rng == *(b.m_rng) && m_average_min_distance == b.m_average_min_distance &&
bob::core::array::hasSameShape(m_zeroethOrderStats, b.m_zeroethOrderStats) && bob::core::array::hasSameShape(m_zeroethOrderStats, b.m_zeroethOrderStats) &&
...@@ -97,7 +127,7 @@ void bob::learn::misc::KMeansTrainer::initialize(bob::learn::misc::KMeansMachine ...@@ -97,7 +127,7 @@ void bob::learn::misc::KMeansTrainer::initialize(bob::learn::misc::KMeansMachine
bool valid = true; bool valid = true;
for(size_t j=0; j<i && valid; ++j) for(size_t j=0; j<i && valid; ++j)
{ {
kmeans.getMean(j, cur_mean); cur_mean = kmeans.getMean(j);
valid = blitz::any(mean != cur_mean); valid = blitz::any(mean != cur_mean);
} }
// if different, stop otherwise, try with another one // if different, stop otherwise, try with another one
...@@ -204,10 +234,6 @@ double bob::learn::misc::KMeansTrainer::computeLikelihood(bob::learn::misc::KMea ...@@ -204,10 +234,6 @@ double bob::learn::misc::KMeansTrainer::computeLikelihood(bob::learn::misc::KMea
return m_average_min_distance; return m_average_min_distance;
} }
void bob::learn::misc::KMeansTrainer::finalize(bob::learn::misc::KMeansMachine& kmeans,
const blitz::Array<double,2>& ar)
{
}
bool bob::learn::misc::KMeansTrainer::resetAccumulators(bob::learn::misc::KMeansMachine& kmeans) bool bob::learn::misc::KMeansTrainer::resetAccumulators(bob::learn::misc::KMeansMachine& kmeans)
{ {
......
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
#define BOB_LEARN_MISC_KMEANSTRAINER_H #define BOB_LEARN_MISC_KMEANSTRAINER_H
#include <bob.learn.misc/KMeansMachine.h> #include <bob.learn.misc/KMeansMachine.h>
#include <bob.learn.misc/EMTrainer.h>
#include <boost/version.hpp> #include <boost/version.hpp>
#include <boost/random/mersenne_twister.hpp>
namespace bob { namespace learn { namespace misc { namespace bob { namespace learn { namespace misc {
...@@ -20,7 +20,7 @@ namespace bob { namespace learn { namespace misc { ...@@ -20,7 +20,7 @@ namespace bob { namespace learn { namespace misc {
* @details See Section 9.1 of Bishop, "Pattern recognition and machine learning", 2006 * @details See Section 9.1 of Bishop, "Pattern recognition and machine learning", 2006
* It uses a random initialisation of the means followed by the expectation-maximization algorithm * It uses a random initialisation of the means followed by the expectation-maximization algorithm
*/ */
class KMeansTrainer: public EMTrainer<bob::learn::misc::KMeansMachine, blitz::Array<double,2> > class KMeansTrainer
{ {
public: public:
/** /**
...@@ -40,9 +40,13 @@ class KMeansTrainer: public EMTrainer<bob::learn::misc::KMeansMachine, blitz::Ar ...@@ -40,9 +40,13 @@ class KMeansTrainer: public EMTrainer<bob::learn::misc::KMeansMachine, blitz::Ar
/** /**
* @brief Constructor * @brief Constructor
*/ */
KMeansTrainer(InitializationMethod=RANDOM);
/*
KMeansTrainer(double convergence_threshold=0.001, KMeansTrainer(double convergence_threshold=0.001,
size_t max_iterations=10, bool compute_likelihood=true, size_t max_iterations=10, bool compute_likelihood=true,
InitializationMethod=RANDOM); InitializationMethod=RANDOM);*/
/** /**
* @brief Virtualize destructor * @brief Virtualize destructor
...@@ -85,7 +89,7 @@ class KMeansTrainer: public EMTrainer<bob::learn::misc::KMeansMachine, blitz::Ar ...@@ -85,7 +89,7 @@ class KMeansTrainer: public EMTrainer<bob::learn::misc::KMeansMachine, blitz::Ar
* Data is split into as many chunks as there are means, * Data is split into as many chunks as there are means,
* then each mean is set to a random example within each chunk. * then each mean is set to a random example within each chunk.
*/ */
virtual void initialize(bob::learn::misc::KMeansMachine& kMeansMachine, void initialize(bob::learn::misc::KMeansMachine& kMeansMachine,
const blitz::Array<double,2>& sampler); const blitz::Array<double,2>& sampler);
/** /**
...@@ -94,25 +98,21 @@ class KMeansTrainer: public EMTrainer<bob::learn::misc::KMeansMachine, blitz::Ar ...@@ -94,25 +98,21 @@ class KMeansTrainer: public EMTrainer<bob::learn::misc::KMeansMachine, blitz::Ar
* - average (Square Euclidean) distance from the closest mean * - average (Square Euclidean) distance from the closest mean
* Implements EMTrainer::eStep(double &) * Implements EMTrainer::eStep(double &)
*/ */
virtual void eStep(bob::learn::misc::KMeansMachine& kmeans, void eStep(bob::learn::misc::KMeansMachine& kmeans,
const blitz::Array<double,2>& data); const blitz::Array<double,2>& data);
/** /**
* @brief Updates the mean based on the statistics from the E-step. * @brief Updates the mean based on the statistics from the E-step.
*/ */
virtual void mStep(bob::learn::misc::KMeansMachine& kmeans, void mStep(bob::learn::misc::KMeansMachine& kmeans,
const blitz::Array<double,2>&); const blitz::Array<double,2>&);
/** /**
* @brief This functions returns the average min (Square Euclidean) * @brief This functions returns the average min (Square Euclidean)
* distance (average distance to the closest mean) * distance (average distance to the closest mean)
*/ */
virtual double computeLikelihood(bob::learn::misc::KMeansMachine& kmeans); double computeLikelihood(bob::learn::misc::KMeansMachine& kmeans);
/**
* @brief Function called at the end of the training
*/
virtual void finalize(bob::learn::misc::KMeansMachine& kMeansMachine, const blitz::Array<double,2>& sampler);
/** /**
* @brief Reset the statistics accumulators * @brief Reset the statistics accumulators
...@@ -156,7 +156,13 @@ class KMeansTrainer: public EMTrainer<bob::learn::misc::KMeansMachine, blitz::Ar ...@@ -156,7 +156,13 @@ class KMeansTrainer: public EMTrainer<bob::learn::misc::KMeansMachine, blitz::Ar
void setAverageMinDistance(const double value) { m_average_min_distance = value; } void setAverageMinDistance(const double value) { m_average_min_distance = value; }
protected: private:
//bool m_compute_likelihood; ///< whether likelihood is computed during the EM loop or not
//double m_convergence_threshold; ///< convergence threshold
//size_t m_max_iterations; ///< maximum number of EM iterations
/** /**
* @brief The initialization method * @brief The initialization method
* Check that there is no duplicated means during the random initialization * Check that there is no duplicated means during the random initialization
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment