Skip to content
Snippets Groups Projects
Commit 28aba742 authored by Manuel Günther's avatar Manuel Günther
Browse files

Moved initialization of cache of KMeansTrainer to where it belongs.

parent 2782bcab
No related branches found
No related tags found
No related merge requests found
...@@ -24,11 +24,11 @@ m_firstOrderStats(0) ...@@ -24,11 +24,11 @@ m_firstOrderStats(0)
bob::learn::em::KMeansTrainer::KMeansTrainer(const bob::learn::em::KMeansTrainer& other){ bob::learn::em::KMeansTrainer::KMeansTrainer(const bob::learn::em::KMeansTrainer& other){
m_initialization_method = other.m_initialization_method; m_initialization_method = other.m_initialization_method;
m_rng = other.m_rng; m_rng = other.m_rng;
m_average_min_distance = other.m_average_min_distance; m_average_min_distance = other.m_average_min_distance;
m_zeroethOrderStats = bob::core::array::ccopy(other.m_zeroethOrderStats); m_zeroethOrderStats = bob::core::array::ccopy(other.m_zeroethOrderStats);
m_firstOrderStats = bob::core::array::ccopy(other.m_firstOrderStats); m_firstOrderStats = bob::core::array::ccopy(other.m_firstOrderStats);
} }
...@@ -160,9 +160,6 @@ void bob::learn::em::KMeansTrainer::initialize(bob::learn::em::KMeansMachine& km ...@@ -160,9 +160,6 @@ void bob::learn::em::KMeansTrainer::initialize(bob::learn::em::KMeansMachine& km
kmeans.setMean(m, new_mean); kmeans.setMean(m, new_mean);
} }
} }
// Resize the accumulator
m_zeroethOrderStats.resize(kmeans.getNMeans());
m_firstOrderStats.resize(kmeans.getNMeans(), kmeans.getNInputs());
} }
void bob::learn::em::KMeansTrainer::eStep(bob::learn::em::KMeansMachine& kmeans, void bob::learn::em::KMeansTrainer::eStep(bob::learn::em::KMeansMachine& kmeans,
...@@ -206,12 +203,16 @@ double bob::learn::em::KMeansTrainer::computeLikelihood(bob::learn::em::KMeansMa ...@@ -206,12 +203,16 @@ double bob::learn::em::KMeansTrainer::computeLikelihood(bob::learn::em::KMeansMa
} }
bool bob::learn::em::KMeansTrainer::resetAccumulators(bob::learn::em::KMeansMachine& kmeans) void bob::learn::em::KMeansTrainer::resetAccumulators(bob::learn::em::KMeansMachine& kmeans)
{ {
// Resize the accumulator
m_zeroethOrderStats.resize(kmeans.getNMeans());
m_firstOrderStats.resize(kmeans.getNMeans(), kmeans.getNInputs());
// initialize with 0
m_average_min_distance = 0; m_average_min_distance = 0;
m_zeroethOrderStats = 0; m_zeroethOrderStats = 0;
m_firstOrderStats = 0; m_firstOrderStats = 0;
return true;
} }
void bob::learn::em::KMeansTrainer::setZeroethOrderStats(const blitz::Array<double,1>& zeroethOrderStats) void bob::learn::em::KMeansTrainer::setZeroethOrderStats(const blitz::Array<double,1>& zeroethOrderStats)
...@@ -225,4 +226,3 @@ void bob::learn::em::KMeansTrainer::setFirstOrderStats(const blitz::Array<double ...@@ -225,4 +226,3 @@ void bob::learn::em::KMeansTrainer::setFirstOrderStats(const blitz::Array<double
bob::core::array::assertSameShape(m_firstOrderStats, firstOrderStats); bob::core::array::assertSameShape(m_firstOrderStats, firstOrderStats);
m_firstOrderStats = firstOrderStats; m_firstOrderStats = firstOrderStats;
} }
...@@ -105,7 +105,7 @@ class KMeansTrainer ...@@ -105,7 +105,7 @@ class KMeansTrainer
* @brief Reset the statistics accumulators * @brief Reset the statistics accumulators
* to the correct size and a value of zero. * to the correct size and a value of zero.
*/ */
bool resetAccumulators(bob::learn::em::KMeansMachine& kMeansMachine); void resetAccumulators(bob::learn::em::KMeansMachine& kMeansMachine);
/** /**
* @brief Sets the Random Number Generator * @brief Sets the Random Number Generator
...@@ -144,7 +144,7 @@ class KMeansTrainer ...@@ -144,7 +144,7 @@ class KMeansTrainer
private: private:
/** /**
* @brief The initialization method * @brief The initialization method
* Check that there is no duplicated means during the random initialization * Check that there is no duplicated means during the random initialization
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment