Commit 84c1a36d authored by Manuel Günther

Made the GMMStats of the GMM training accessible.

parent 8619c0c1
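This commit stores the trainer's sufficient statistics behind a boost::shared_ptr and exposes them to Python through a new `gmm_statistics` property on the ML and MAP GMM trainers, so that the E-step can be distributed across several jobs. A minimal sketch of the workflow this enables (not code from the repository; it assumes the Python signatures shown in the bindings below, and merging two GMMStats objects with `+` is an assumption):

# Sketch of the parallel E-step workflow enabled by this commit.
import numpy
import bob.learn.em

data = numpy.random.randn(1000, 3)
machine = bob.learn.em.GMMMachine(2, 3)      # 2 Gaussians, 3 dimensions

# each chunk would be handled by its own trainer, e.g. on its own node
merged = None
for chunk in numpy.array_split(data, 4):
    trainer = bob.learn.em.ML_GMMTrainer(update_means=True)
    trainer.initialize(machine)
    trainer.e_step(machine, chunk)
    stats = trainer.gmm_statistics           # new getter from this commit
    merged = stats if merged is None else merged + stats  # assumed merge

# feed the merged statistics into one trainer and run a single M-step
trainer.gmm_statistics = merged              # new setter from this commit
trainer.m_step(machine)

Note that the C++ setter asserts matching shapes (see setGMMStats below), so the trainer must be initialized before merged statistics are assigned.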
@@ -12,12 +12,14 @@
bob::learn::em::GMMBaseTrainer::GMMBaseTrainer(const bool update_means,
const bool update_variances, const bool update_weights,
const double mean_var_update_responsibilities_threshold):
+ m_ss(new bob::learn::em::GMMStats()),
m_update_means(update_means), m_update_variances(update_variances),
m_update_weights(update_weights),
m_mean_var_update_responsibilities_threshold(mean_var_update_responsibilities_threshold)
{}
bob::learn::em::GMMBaseTrainer::GMMBaseTrainer(const bob::learn::em::GMMBaseTrainer& b):
+ m_ss(new bob::learn::em::GMMStats()),
m_update_means(b.m_update_means), m_update_variances(b.m_update_variances),
m_mean_var_update_responsibilities_threshold(b.m_mean_var_update_responsibilities_threshold)
{}
@@ -28,20 +30,20 @@ bob::learn::em::GMMBaseTrainer::~GMMBaseTrainer()
void bob::learn::em::GMMBaseTrainer::initialize(bob::learn::em::GMMMachine& gmm)
{
// Allocate memory for the sufficient statistics and initialise
- m_ss.resize(gmm.getNGaussians(),gmm.getNInputs());
+ m_ss->resize(gmm.getNGaussians(),gmm.getNInputs());
}
void bob::learn::em::GMMBaseTrainer::eStep(bob::learn::em::GMMMachine& gmm,
const blitz::Array<double,2>& data)
{
- m_ss.init();
+ m_ss->init();
// Calculate the sufficient statistics and save in m_ss
- gmm.accStatistics(data, m_ss);
+ gmm.accStatistics(data, *m_ss);
}
double bob::learn::em::GMMBaseTrainer::computeLikelihood(bob::learn::em::GMMMachine& gmm)
{
- return m_ss.log_likelihood / m_ss.T;
+ return m_ss->log_likelihood / m_ss->T;
}
@@ -50,7 +52,7 @@ bob::learn::em::GMMBaseTrainer& bob::learn::em::GMMBaseTrainer::operator=
{
if (this != &other)
{
- m_ss = other.m_ss;
+ *m_ss = *other.m_ss;
m_update_means = other.m_update_means;
m_update_variances = other.m_update_variances;
m_update_weights = other.m_update_weights;
@@ -62,7 +64,7 @@ bob::learn::em::GMMBaseTrainer& bob::learn::em::GMMBaseTrainer::operator=
bool bob::learn::em::GMMBaseTrainer::operator==
(const bob::learn::em::GMMBaseTrainer &other) const
{
- return m_ss == other.m_ss &&
+ return *m_ss == *other.m_ss &&
m_update_means == other.m_update_means &&
m_update_variances == other.m_update_variances &&
m_update_weights == other.m_update_weights &&
@@ -79,7 +81,7 @@ bool bob::learn::em::GMMBaseTrainer::is_similar_to
(const bob::learn::em::GMMBaseTrainer &other, const double r_epsilon,
const double a_epsilon) const
{
- return m_ss == other.m_ss &&
+ return *m_ss == *other.m_ss &&
m_update_means == other.m_update_means &&
m_update_variances == other.m_update_variances &&
m_update_weights == other.m_update_weights &&
@@ -87,8 +89,8 @@ bool bob::learn::em::GMMBaseTrainer::is_similar_to
other.m_mean_var_update_responsibilities_threshold, r_epsilon, a_epsilon);
}
- void bob::learn::em::GMMBaseTrainer::setGMMStats(const bob::learn::em::GMMStats& stats)
+ void bob::learn::em::GMMBaseTrainer::setGMMStats(boost::shared_ptr<bob::learn::em::GMMStats> stats)
{
- bob::core::array::assertSameShape(m_ss.sumPx, stats.sumPx);
+ bob::core::array::assertSameShape(m_ss->sumPx, stats->sumPx);
m_ss = stats;
}
@@ -14,8 +14,8 @@ bob::learn::em::MAP_GMMTrainer::MAP_GMMTrainer(
const bool update_weights,
const double mean_var_update_responsibilities_threshold,
const bool reynolds_adaptation,
const double relevance_factor,
const double alpha,
boost::shared_ptr<bob::learn::em::GMMMachine> prior_gmm):
@@ -33,7 +33,7 @@ bob::learn::em::MAP_GMMTrainer::MAP_GMMTrainer(const bob::learn::em::MAP_GMMTrai
m_prior_gmm(b.m_prior_gmm)
{
m_relevance_factor = b.m_relevance_factor;
m_alpha = b.m_alpha;
m_reynolds_adaptation = b.m_reynolds_adaptation;
}
@@ -75,10 +75,10 @@ void bob::learn::em::MAP_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
{
// Read options and variables
double n_gaussians = gmm.getNGaussians();
//Checking if it is necessary to resize the cache
if((size_t)m_cache_alpha.extent(0) != n_gaussians)
initialize(gmm); //If it is different for some reason, there is no way, you have to initialize
// Check that the prior GMM has been specified
if (!m_prior_gmm)
@@ -92,13 +92,13 @@ void bob::learn::em::MAP_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
if (!m_reynolds_adaptation)
m_cache_alpha = m_alpha;
else
- m_cache_alpha = m_gmm_base_trainer.getGMMStats().n(i) / (m_gmm_base_trainer.getGMMStats().n(i) + m_relevance_factor);
+ m_cache_alpha = m_gmm_base_trainer.getGMMStats()->n(i) / (m_gmm_base_trainer.getGMMStats()->n(i) + m_relevance_factor);
// - Update weights if requested
// Equation 11 of Reynolds et al., "Speaker Verification Using Adapted Gaussian Mixture Models", Digital Signal Processing, 2000
if (m_gmm_base_trainer.getUpdateWeights()) {
// Calculate the maximum likelihood weights
- m_cache_ml_weights = m_gmm_base_trainer.getGMMStats().n / static_cast<double>(m_gmm_base_trainer.getGMMStats().T); //cast req. for linux/32-bits & osx
+ m_cache_ml_weights = m_gmm_base_trainer.getGMMStats()->n / static_cast<double>(m_gmm_base_trainer.getGMMStats()->T); //cast req. for linux/32-bits & osx
// Get the prior weights
const blitz::Array<double,1>& prior_weights = m_prior_gmm->getWeights();
@@ -123,12 +123,12 @@ void bob::learn::em::MAP_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
for (size_t i=0; i<n_gaussians; ++i) {
const blitz::Array<double,1>& prior_means = m_prior_gmm->getGaussian(i)->getMean();
blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean();
- if (m_gmm_base_trainer.getGMMStats().n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) {
+ if (m_gmm_base_trainer.getGMMStats()->n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) {
means = prior_means;
}
else {
// Use the maximum likelihood means
- means = m_cache_alpha(i) * (m_gmm_base_trainer.getGMMStats().sumPx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats().n(i)) + (1-m_cache_alpha(i)) * prior_means;
+ means = m_cache_alpha(i) * (m_gmm_base_trainer.getGMMStats()->sumPx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats()->n(i)) + (1-m_cache_alpha(i)) * prior_means;
}
}
}
@@ -142,11 +142,11 @@ void bob::learn::em::MAP_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean();
const blitz::Array<double,1>& prior_variances = m_prior_gmm->getGaussian(i)->getVariance();
blitz::Array<double,1>& variances = gmm.getGaussian(i)->updateVariance();
- if (m_gmm_base_trainer.getGMMStats().n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) {
+ if (m_gmm_base_trainer.getGMMStats()->n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) {
variances = (prior_variances + prior_means) - blitz::pow2(means);
}
else {
- variances = m_cache_alpha(i) * m_gmm_base_trainer.getGMMStats().sumPxx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats().n(i) + (1-m_cache_alpha(i)) * (prior_variances + prior_means) - blitz::pow2(means);
+ variances = m_cache_alpha(i) * m_gmm_base_trainer.getGMMStats()->sumPxx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats()->n(i) + (1-m_cache_alpha(i)) * (prior_variances + prior_means) - blitz::pow2(means);
}
gmm.getGaussian(i)->applyVarianceThresholds();
}
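For reference, the mean and variance updates above implement the MAP adaptation equations of Reynolds et al. (2000), the paper already cited for the weight update. With the data-dependent coefficient used when Reynolds adaptation is enabled,

\alpha_i = \frac{n_i}{n_i + r}, \qquad
\hat{\mu}_i = \alpha_i \, \frac{\mathrm{sumPx}_i}{n_i} + (1 - \alpha_i)\, \mu_i, \qquad
\hat{\sigma}_i^2 = \alpha_i \, \frac{\mathrm{sumPxx}_i}{n_i} + (1 - \alpha_i)\,\left(\sigma_i^2 + \mu_i^2\right) - \hat{\mu}_i^2,

where n_i, sumPx_i and sumPxx_i are the zeroth-, first- and second-order statistics accumulated in the E-step, r is the relevance factor, and \mu_i, \sigma_i^2 are the prior model's parameters.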
@@ -200,4 +200,3 @@ bool bob::learn::em::MAP_GMMTrainer::is_similar_to
bob::core::isClose(m_alpha, other.m_alpha, r_epsilon, a_epsilon) &&
m_reynolds_adaptation == other.m_reynolds_adaptation;
}
@@ -10,7 +10,7 @@
bob::learn::em::ML_GMMTrainer::ML_GMMTrainer(
const bool update_means,
const bool update_variances,
const bool update_weights,
const double mean_var_update_responsibilities_threshold
):
@@ -29,7 +29,7 @@ bob::learn::em::ML_GMMTrainer::~ML_GMMTrainer()
void bob::learn::em::ML_GMMTrainer::initialize(bob::learn::em::GMMMachine& gmm)
{
m_gmm_base_trainer.initialize(gmm);
// Allocate cache
size_t n_gaussians = gmm.getNGaussians();
m_cache_ss_n_thresholded.resize(n_gaussians);
@@ -49,14 +49,14 @@ void bob::learn::em::ML_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
// Equation 9.26 of Bishop, "Pattern recognition and machine learning", 2006
if (m_gmm_base_trainer.getUpdateWeights()) {
blitz::Array<double,1>& weights = gmm.updateWeights();
- weights = m_gmm_base_trainer.getGMMStats().n / static_cast<double>(m_gmm_base_trainer.getGMMStats().T); //cast req. for linux/32-bits & osx
+ weights = m_gmm_base_trainer.getGMMStats()->n / static_cast<double>(m_gmm_base_trainer.getGMMStats()->T); //cast req. for linux/32-bits & osx
// Recompute the log weights in the cache of the GMMMachine
gmm.recomputeLogWeights();
}
// Generate a thresholded version of m_ss.n
for(size_t i=0; i<n_gaussians; ++i)
- m_cache_ss_n_thresholded(i) = std::max(m_gmm_base_trainer.getGMMStats().n(i), m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold());
+ m_cache_ss_n_thresholded(i) = std::max(m_gmm_base_trainer.getGMMStats()->n(i), m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold());
// Update GMM parameters using the sufficient statistics (m_ss)
// - Update means if requested
@@ -64,7 +64,7 @@ void bob::learn::em::ML_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
if (m_gmm_base_trainer.getUpdateMeans()) {
for(size_t i=0; i<n_gaussians; ++i) {
blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean();
- means = m_gmm_base_trainer.getGMMStats().sumPx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i);
+ means = m_gmm_base_trainer.getGMMStats()->sumPx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i);
}
}
@@ -77,7 +77,7 @@ void bob::learn::em::ML_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
for(size_t i=0; i<n_gaussians; ++i) {
const blitz::Array<double,1>& means = gmm.getGaussian(i)->getMean();
blitz::Array<double,1>& variances = gmm.getGaussian(i)->updateVariance();
- variances = m_gmm_base_trainer.getGMMStats().sumPxx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i) - blitz::pow2(means);
+ variances = m_gmm_base_trainer.getGMMStats()->sumPxx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i) - blitz::pow2(means);
gmm.getGaussian(i)->applyVarianceThresholds();
}
}
......
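For reference, the maximum-likelihood updates above are Equations 9.24–9.26 of Bishop (2006), restricted to diagonal covariances and with N_k thresholded as in the code:

\pi_k = \frac{N_k}{N}, \qquad
\mu_k = \frac{1}{N_k} \sum_{n=1}^{N} \gamma(z_{nk})\, x_n, \qquad
\sigma_k^2 = \frac{1}{N_k} \sum_{n=1}^{N} \gamma(z_{nk})\, x_n^2 - \mu_k^2,

where N_k = \sum_n \gamma(z_{nk}) is the `n` field of the sufficient statistics, the first- and second-order sums are sumPx and sumPxx, and N corresponds to T.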
@@ -30,7 +30,7 @@ class GMMBaseTrainer
* @brief Default constructor
*/
GMMBaseTrainer(const bool update_means=true,
const bool update_variances=false,
const bool update_weights=false,
const double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon());
@@ -93,21 +93,21 @@ class GMMBaseTrainer
* @brief Returns the internal GMM statistics. Useful to parallelize the
* E-step
*/
- const bob::learn::em::GMMStats getGMMStats() const
+ const boost::shared_ptr<bob::learn::em::GMMStats> getGMMStats() const
{ return m_ss; }
/**
* @brief Sets the internal GMM statistics. Useful to parallelize the
* E-step
*/
- void setGMMStats(const bob::learn::em::GMMStats& stats);
+ void setGMMStats(boost::shared_ptr<bob::learn::em::GMMStats> stats);
/**
* update means on each iteration
*/
bool getUpdateMeans()
{return m_update_means;}
/**
* update variances on each iteration
*/
@@ -117,19 +117,19 @@ class GMMBaseTrainer
bool getUpdateWeights()
{return m_update_weights;}
double getMeanVarUpdateResponsibilitiesThreshold()
{return m_mean_var_update_responsibilities_threshold;}
private:
/**
* These are the sufficient statistics, calculated during the
* E-step and used during the M-step
*/
- bob::learn::em::GMMStats m_ss;
+ boost::shared_ptr<bob::learn::em::GMMStats> m_ss;
/**
......
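Note the semantic change above: getGMMStats() now returns the boost::shared_ptr itself, so callers receive a handle on the trainer's live statistics rather than a copy, and setGMMStats() stores the given pointer after checking shapes. A two-line Python illustration of the consequence (inferred from the shared_ptr change, not shown in this diff):

# the returned object wraps the same GMMStats instance the trainer uses,
# so in-place modifications are visible to the trainer as well
stats = trainer.gmm_statistics   # no copy: shares the trainer's statistics
stats.init()                     # inferred: resetting it also clears the
                                 # trainer's accumulated statistics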
@@ -28,11 +28,11 @@ class MAP_GMMTrainer
*/
MAP_GMMTrainer(
const bool update_means=true,
const bool update_variances=false,
const bool update_weights=false,
const double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon(),
const bool reynolds_adaptation=false,
const double relevance_factor=4,
const double alpha=0.5,
boost::shared_ptr<bob::learn::em::GMMMachine> prior_gmm = boost::shared_ptr<bob::learn::em::GMMMachine>());
@@ -108,14 +108,14 @@ class MAP_GMMTrainer
*/
double computeLikelihood(bob::learn::em::GMMMachine& gmm){
return m_gmm_base_trainer.computeLikelihood(gmm);
}
bool getReynoldsAdaptation()
{return m_reynolds_adaptation;}
void setReynoldsAdaptation(const bool reynolds_adaptation)
{m_reynolds_adaptation = reynolds_adaptation;}
double getRelevanceFactor()
{return m_relevance_factor;}
@@ -130,6 +130,8 @@ class MAP_GMMTrainer
void setAlpha(const double alpha)
{m_alpha = alpha;}
+ bob::learn::em::GMMBaseTrainer& base_trainer(){return m_gmm_base_trainer;}
protected:
@@ -140,7 +142,7 @@ class MAP_GMMTrainer
/**
Base Trainer for the MAP algorithm. Basically implements the e-step
*/
bob::learn::em::GMMBaseTrainer m_gmm_base_trainer;
/**
......
@@ -28,7 +28,7 @@ class ML_GMMTrainer{
* @brief Default constructor
*/
ML_GMMTrainer(const bool update_means=true,
const bool update_variances=false,
const bool update_weights=false,
const double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon());
@@ -97,13 +97,15 @@ class ML_GMMTrainer{
*/
bool is_similar_to(const ML_GMMTrainer& b, const double r_epsilon=1e-5,
const double a_epsilon=1e-8) const;
+ bob::learn::em::GMMBaseTrainer& base_trainer(){return m_gmm_base_trainer;}
protected:
/**
Base Trainer for the MAP algorithm. Basically implements the e-step
*/
bob::learn::em::GMMBaseTrainer m_gmm_base_trainer;
......
@@ -244,6 +244,30 @@ int PyBobLearnEMMAPGMMTrainer_setAlpha(PyBobLearnEMMAPGMMTrainerObject* self, Py
}
+ static auto gmm_statistics = bob::extension::VariableDoc(
+   "gmm_statistics",
+   ":py:class:`GMMStats`",
+   "The GMM statistics that were used internally in the E- and M-steps",
+   "Setting and getting the internal GMM statistics might be useful to parallelize the GMM training."
+ );
+ PyObject* PyBobLearnEMMAPGMMTrainer_get_gmm_statistics(PyBobLearnEMMAPGMMTrainerObject* self, void*){
+   BOB_TRY
+   PyBobLearnEMGMMStatsObject* stats = (PyBobLearnEMGMMStatsObject*)PyBobLearnEMGMMStats_Type.tp_alloc(&PyBobLearnEMGMMStats_Type, 0);
+   stats->cxx = self->cxx->base_trainer().getGMMStats();
+   return Py_BuildValue("N", stats);
+   BOB_CATCH_MEMBER("gmm_statistics could not be read", 0)
+ }
+ int PyBobLearnEMMAPGMMTrainer_set_gmm_statistics(PyBobLearnEMMAPGMMTrainerObject* self, PyObject* value, void*){
+   BOB_TRY
+   if (!PyBobLearnEMGMMStats_Check(value)){
+     PyErr_Format(PyExc_RuntimeError, "%s %s expects a GMMStats object", Py_TYPE(self)->tp_name, gmm_statistics.name());
+     return -1;
+   }
+   self->cxx->base_trainer().setGMMStats(reinterpret_cast<PyBobLearnEMGMMStatsObject*>(value)->cxx);
+   return 0;
+   BOB_CATCH_MEMBER("gmm_statistics could not be set", -1)
+ }
static PyGetSetDef PyBobLearnEMMAPGMMTrainer_getseters[] = {
{
@@ -260,7 +284,13 @@ static PyGetSetDef PyBobLearnEMMAPGMMTrainer_getseters[] = {
relevance_factor.doc(),
0
},
+ {
+   gmm_statistics.name(),
+   (getter)PyBobLearnEMMAPGMMTrainer_get_gmm_statistics,
+   (setter)PyBobLearnEMMAPGMMTrainer_set_gmm_statistics,
+   gmm_statistics.doc(),
+   0
+ },
{0} // Sentinel
};
@@ -303,12 +333,11 @@ static PyObject* PyBobLearnEMMAPGMMTrainer_initialize(PyBobLearnEMMAPGMMTrainerO
/*** e_step ***/
static auto e_step = bob::extension::FunctionDoc(
"e_step",
"Calculates and saves statistics across the dataset,"
"and saves these as m_ss. ",
"Calculates and saves statistics across the dataset and saves these as :py:attr`gmm_statistics`. ",
"Calculates the average log likelihood of the observations given the GMM,"
"and returns this in average_log_likelihood."
"The statistics, m_ss, will be used in the m_step() that follows.",
"The statistics, :py:attr`gmm_statistics`, will be used in the :py:meth:`m_step` that follows.",
true
)
@@ -357,9 +386,7 @@ static PyObject* PyBobLearnEMMAPGMMTrainer_e_step(PyBobLearnEMMAPGMMTrainerObjec
static auto m_step = bob::extension::FunctionDoc(
"m_step",
"Performs a maximum a posteriori (MAP) update of the GMM:"
"* parameters using the accumulated statistics in :py:class:`bob.learn.em.GMMBaseTrainer.m_ss` and the"
"* parameters of the prior model",
"Performs a maximum a posteriori (MAP) update of the GMM parameters using the accumulated statistics in :py:attr:`gmm_statistics` and the parameters of the prior model",
"",
true
)
......
@@ -144,7 +144,39 @@ static PyObject* PyBobLearnEMMLGMMTrainer_RichCompare(PyBobLearnEMMLGMMTrainerOb
/************ Variables Section ***********************************/
/******************************************************************/
+ static auto gmm_statistics = bob::extension::VariableDoc(
+   "gmm_statistics",
+   ":py:class:`GMMStats`",
+   "The GMM statistics that were used internally in the E- and M-steps",
+   "Setting and getting the internal GMM statistics might be useful to parallelize the GMM training."
+ );
+ PyObject* PyBobLearnEMMLGMMTrainer_get_gmm_statistics(PyBobLearnEMMLGMMTrainerObject* self, void*){
+   BOB_TRY
+   PyBobLearnEMGMMStatsObject* stats = (PyBobLearnEMGMMStatsObject*)PyBobLearnEMGMMStats_Type.tp_alloc(&PyBobLearnEMGMMStats_Type, 0);
+   stats->cxx = self->cxx->base_trainer().getGMMStats();
+   return Py_BuildValue("N", stats);
+   BOB_CATCH_MEMBER("gmm_statistics could not be read", 0)
+ }
+ int PyBobLearnEMMLGMMTrainer_set_gmm_statistics(PyBobLearnEMMLGMMTrainerObject* self, PyObject* value, void*){
+   BOB_TRY
+   if (!PyBobLearnEMGMMStats_Check(value)){
+     PyErr_Format(PyExc_RuntimeError, "%s %s expects a GMMStats object", Py_TYPE(self)->tp_name, gmm_statistics.name());
+     return -1;
+   }
+   self->cxx->base_trainer().setGMMStats(reinterpret_cast<PyBobLearnEMGMMStatsObject*>(value)->cxx);
+   return 0;
+   BOB_CATCH_MEMBER("gmm_statistics could not be set", -1)
+ }
static PyGetSetDef PyBobLearnEMMLGMMTrainer_getseters[] = {
+ {
+   gmm_statistics.name(),
+   (getter)PyBobLearnEMMLGMMTrainer_get_gmm_statistics,
+   (setter)PyBobLearnEMMLGMMTrainer_set_gmm_statistics,
+   gmm_statistics.doc(),
+   0
+ },
{0} // Sentinel
};
@@ -190,7 +222,7 @@ static auto e_step = bob::extension::FunctionDoc(
"Calculates the average log likelihood of the observations given the GMM,"
"and returns this in average_log_likelihood."
"The statistics, m_ss, will be used in the :py:func:`m_step` that follows.",
"The statistics, :py:attr:`gmm_statistics`, will be used in the :py:func:`m_step` that follows.",
true
)
@@ -237,7 +269,7 @@ static PyObject* PyBobLearnEMMLGMMTrainer_e_step(PyBobLearnEMMLGMMTrainerObject*
static auto m_step = bob::extension::FunctionDoc(
"m_step",
"Performs a maximum likelihood (ML) update of the GMM parameters "
"using the accumulated statistics in :py:attr:`bob.learn.em.GMMBaseTrainer.m_ss`",
"using the accumulated statistics in :py:attr:`gmm_statistics`",
"See Section 9.2.2 of Bishop, \"Pattern recognition and machine learning\", 2006",
......