diff --git a/bob/learn/misc/cpp/KMeansTrainer.cpp b/bob/learn/misc/cpp/KMeansTrainer.cpp index 092c87d48698d46e4a45d1ac03a0c4337b5ed04a..583543f7bbc1b249a435cc3242fabbf14c251e84 100644 --- a/bob/learn/misc/cpp/KMeansTrainer.cpp +++ b/bob/learn/misc/cpp/KMeansTrainer.cpp @@ -12,43 +12,73 @@ #include <boost/random.hpp> +/* bob::learn::misc::KMeansTrainer::KMeansTrainer(double convergence_threshold, - size_t max_iterations, bool compute_likelihood, InitializationMethod i_m): - bob::learn::misc::EMTrainer<bob::learn::misc::KMeansMachine, blitz::Array<double,2> >( - convergence_threshold, max_iterations, compute_likelihood), - m_initialization_method(i_m), - m_rng(new boost::mt19937()), m_average_min_distance(0), - m_zeroethOrderStats(0), m_firstOrderStats(0,0) + size_t max_iterations, bool compute_likelihood, InitializationMethod i_m) { + + m_initialization_method = i_m; + m_zeroethOrderStats = 0; + m_firstOrderStats = 0; + m_average_min_distance = 0; + + m_compute_likelihood = compute_likelihood; + m_convergence_threshold = convergence_threshold; + m_max_iterations = max_iterations; + //m_rng(new boost::mt19937()); + } +*/ -bob::learn::misc::KMeansTrainer::KMeansTrainer(const bob::learn::misc::KMeansTrainer& other): - bob::learn::misc::EMTrainer<bob::learn::misc::KMeansMachine, blitz::Array<double,2> >( - other.m_convergence_threshold, other.m_max_iterations, other.m_compute_likelihood), - m_initialization_method(other.m_initialization_method), - m_rng(other.m_rng), m_average_min_distance(other.m_average_min_distance), - m_zeroethOrderStats(bob::core::array::ccopy(other.m_zeroethOrderStats)), - m_firstOrderStats(bob::core::array::ccopy(other.m_firstOrderStats)) +bob::learn::misc::KMeansTrainer::KMeansTrainer(InitializationMethod i_m) { + + m_initialization_method = i_m; + m_zeroethOrderStats = 0; + m_firstOrderStats = 0; + m_average_min_distance = 0; + + //m_rng(new boost::mt19937()); +} + + +bob::learn::misc::KMeansTrainer::KMeansTrainer(const bob::learn::misc::KMeansTrainer& other){ + + //m_convergence_threshold = other.m_convergence_threshold; + //m_max_iterations = other.m_max_iterations; + //m_compute_likelihood = other.m_compute_likelihood; + + m_initialization_method = other.m_initialization_method; + m_rng = other.m_rng; + m_average_min_distance = other.m_average_min_distance; + m_zeroethOrderStats = bob::core::array::ccopy(other.m_zeroethOrderStats); + m_firstOrderStats = bob::core::array::ccopy(other.m_firstOrderStats); } + bob::learn::misc::KMeansTrainer& bob::learn::misc::KMeansTrainer::operator= (const bob::learn::misc::KMeansTrainer& other) { if(this != &other) { - EMTrainer<bob::learn::misc::KMeansMachine, blitz::Array<double,2> >::operator=(other); - m_initialization_method = other.m_initialization_method; - m_rng = other.m_rng; - m_average_min_distance = other.m_average_min_distance; - m_zeroethOrderStats.reference(bob::core::array::ccopy(other.m_zeroethOrderStats)); - m_firstOrderStats.reference(bob::core::array::ccopy(other.m_firstOrderStats)); + //m_compute_likelihood = other.m_compute_likelihood; + //m_convergence_threshold = other.m_convergence_threshold; + //m_max_iterations = other.m_max_iterations; + m_rng = other.m_rng; + m_initialization_method = other.m_initialization_method; + m_average_min_distance = other.m_average_min_distance; + + m_zeroethOrderStats = bob::core::array::ccopy(other.m_zeroethOrderStats); + m_firstOrderStats = bob::core::array::ccopy(other.m_firstOrderStats); } return *this; } + bool bob::learn::misc::KMeansTrainer::operator==(const bob::learn::misc::KMeansTrainer& b) const { - return EMTrainer<bob::learn::misc::KMeansMachine, blitz::Array<double,2> >::operator==(b) && + return //m_compute_likelihood == b.m_compute_likelihood && + //m_convergence_threshold == b.m_convergence_threshold && + //m_max_iterations == b.m_max_iterations && m_initialization_method == b.m_initialization_method && *m_rng == *(b.m_rng) && m_average_min_distance == b.m_average_min_distance && bob::core::array::hasSameShape(m_zeroethOrderStats, b.m_zeroethOrderStats) && @@ -97,7 +127,7 @@ void bob::learn::misc::KMeansTrainer::initialize(bob::learn::misc::KMeansMachine bool valid = true; for(size_t j=0; j<i && valid; ++j) { - kmeans.getMean(j, cur_mean); + cur_mean = kmeans.getMean(j); valid = blitz::any(mean != cur_mean); } // if different, stop otherwise, try with another one @@ -204,10 +234,6 @@ double bob::learn::misc::KMeansTrainer::computeLikelihood(bob::learn::misc::KMea return m_average_min_distance; } -void bob::learn::misc::KMeansTrainer::finalize(bob::learn::misc::KMeansMachine& kmeans, - const blitz::Array<double,2>& ar) -{ -} bool bob::learn::misc::KMeansTrainer::resetAccumulators(bob::learn::misc::KMeansMachine& kmeans) { diff --git a/bob/learn/misc/include/bob.learn.misc/KMeansTrainer.h b/bob/learn/misc/include/bob.learn.misc/KMeansTrainer.h index 110b82ae16a5304890f5a86958978998b7567f43..caba9f64a1c2df62f7f71d344cdaa280d9af152b 100644 --- a/bob/learn/misc/include/bob.learn.misc/KMeansTrainer.h +++ b/bob/learn/misc/include/bob.learn.misc/KMeansTrainer.h @@ -9,8 +9,8 @@ #define BOB_LEARN_MISC_KMEANSTRAINER_H #include <bob.learn.misc/KMeansMachine.h> -#include <bob.learn.misc/EMTrainer.h> #include <boost/version.hpp> +#include <boost/random/mersenne_twister.hpp> namespace bob { namespace learn { namespace misc { @@ -20,7 +20,7 @@ namespace bob { namespace learn { namespace misc { * @details See Section 9.1 of Bishop, "Pattern recognition and machine learning", 2006 * It uses a random initialisation of the means followed by the expectation-maximization algorithm */ -class KMeansTrainer: public EMTrainer<bob::learn::misc::KMeansMachine, blitz::Array<double,2> > +class KMeansTrainer { public: /** @@ -40,9 +40,13 @@ class KMeansTrainer: public EMTrainer<bob::learn::misc::KMeansMachine, blitz::Ar /** * @brief Constructor */ + KMeansTrainer(InitializationMethod=RANDOM); + + /* KMeansTrainer(double convergence_threshold=0.001, size_t max_iterations=10, bool compute_likelihood=true, - InitializationMethod=RANDOM); + InitializationMethod=RANDOM);*/ + /** * @brief Virtualize destructor @@ -85,7 +89,7 @@ class KMeansTrainer: public EMTrainer<bob::learn::misc::KMeansMachine, blitz::Ar * Data is split into as many chunks as there are means, * then each mean is set to a random example within each chunk. */ - virtual void initialize(bob::learn::misc::KMeansMachine& kMeansMachine, + void initialize(bob::learn::misc::KMeansMachine& kMeansMachine, const blitz::Array<double,2>& sampler); /** @@ -94,25 +98,21 @@ class KMeansTrainer: public EMTrainer<bob::learn::misc::KMeansMachine, blitz::Ar * - average (Square Euclidean) distance from the closest mean * Implements EMTrainer::eStep(double &) */ - virtual void eStep(bob::learn::misc::KMeansMachine& kmeans, + void eStep(bob::learn::misc::KMeansMachine& kmeans, const blitz::Array<double,2>& data); /** * @brief Updates the mean based on the statistics from the E-step. */ - virtual void mStep(bob::learn::misc::KMeansMachine& kmeans, + void mStep(bob::learn::misc::KMeansMachine& kmeans, const blitz::Array<double,2>&); /** * @brief This functions returns the average min (Square Euclidean) * distance (average distance to the closest mean) */ - virtual double computeLikelihood(bob::learn::misc::KMeansMachine& kmeans); + double computeLikelihood(bob::learn::misc::KMeansMachine& kmeans); - /** - * @brief Function called at the end of the training - */ - virtual void finalize(bob::learn::misc::KMeansMachine& kMeansMachine, const blitz::Array<double,2>& sampler); /** * @brief Reset the statistics accumulators @@ -156,7 +156,13 @@ class KMeansTrainer: public EMTrainer<bob::learn::misc::KMeansMachine, blitz::Ar void setAverageMinDistance(const double value) { m_average_min_distance = value; } - protected: + private: + + //bool m_compute_likelihood; ///< whether lilelihood is computed during the EM loop or not + //double m_convergence_threshold; ///< convergence threshold + //size_t m_max_iterations; ///< maximum number of EM iterations + + /** * @brief The initialization method * Check that there is no duplicated means during the random initialization