Commit 84c1a36d authored by Manuel Günther

Made the GMMStats of the GMM training accessible.

parent 8619c0c1
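In user-facing terms, this commit exposes the trainer's internal `GMMStats` object through the Python bindings as a read/write `gmm_statistics` property on `ML_GMMTrainer` and `MAP_GMMTrainer` (see the binding hunks below). A minimal sketch of what becomes possible; the `GMMMachine`/`ML_GMMTrainer` constructor arguments and the `GMMStats` attribute names are assumptions, not shown in this diff:

```python
# Sketch only: `gmm_statistics` is what this commit adds; the
# constructor signatures below are assumptions.
import numpy
import bob.learn.em

data = numpy.random.randn(100, 3)        # 100 samples, 3 dimensions
machine = bob.learn.em.GMMMachine(2, 3)  # 2 Gaussians, 3 dimensions
trainer = bob.learn.em.ML_GMMTrainer(True, True, True)

trainer.initialize(machine)
trainer.e_step(machine, data)

stats = trainer.gmm_statistics           # the newly accessible GMMStats
print(stats.t, stats.log_likelihood)     # sample count, accumulated log-likelihood
```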
@@ -12,12 +12,14 @@
 bob::learn::em::GMMBaseTrainer::GMMBaseTrainer(const bool update_means,
   const bool update_variances, const bool update_weights,
   const double mean_var_update_responsibilities_threshold):
+  m_ss(new bob::learn::em::GMMStats()),
   m_update_means(update_means), m_update_variances(update_variances),
   m_update_weights(update_weights),
   m_mean_var_update_responsibilities_threshold(mean_var_update_responsibilities_threshold)
 {}

 bob::learn::em::GMMBaseTrainer::GMMBaseTrainer(const bob::learn::em::GMMBaseTrainer& b):
+  m_ss(new bob::learn::em::GMMStats()),
   m_update_means(b.m_update_means), m_update_variances(b.m_update_variances),
   m_mean_var_update_responsibilities_threshold(b.m_mean_var_update_responsibilities_threshold)
 {}
@@ -28,20 +30,20 @@ bob::learn::em::GMMBaseTrainer::~GMMBaseTrainer()
 void bob::learn::em::GMMBaseTrainer::initialize(bob::learn::em::GMMMachine& gmm)
 {
   // Allocate memory for the sufficient statistics and initialise
-  m_ss.resize(gmm.getNGaussians(),gmm.getNInputs());
+  m_ss->resize(gmm.getNGaussians(),gmm.getNInputs());
 }

 void bob::learn::em::GMMBaseTrainer::eStep(bob::learn::em::GMMMachine& gmm,
   const blitz::Array<double,2>& data)
 {
-  m_ss.init();
+  m_ss->init();
   // Calculate the sufficient statistics and save in m_ss
-  gmm.accStatistics(data, m_ss);
+  gmm.accStatistics(data, *m_ss);
 }

 double bob::learn::em::GMMBaseTrainer::computeLikelihood(bob::learn::em::GMMMachine& gmm)
 {
-  return m_ss.log_likelihood / m_ss.T;
+  return m_ss->log_likelihood / m_ss->T;
 }
@@ -50,7 +52,7 @@ bob::learn::em::GMMBaseTrainer& bob::learn::em::GMMBaseTrainer::operator=
 {
   if (this != &other)
   {
-    m_ss = other.m_ss;
+    *m_ss = *other.m_ss;
     m_update_means = other.m_update_means;
     m_update_variances = other.m_update_variances;
     m_update_weights = other.m_update_weights;
@@ -62,7 +64,7 @@ bob::learn::em::GMMBaseTrainer& bob::learn::em::GMMBaseTrainer::operator=
 bool bob::learn::em::GMMBaseTrainer::operator==
   (const bob::learn::em::GMMBaseTrainer &other) const
 {
-  return m_ss == other.m_ss &&
+  return *m_ss == *other.m_ss &&
     m_update_means == other.m_update_means &&
     m_update_variances == other.m_update_variances &&
     m_update_weights == other.m_update_weights &&
@@ -79,7 +81,7 @@ bool bob::learn::em::GMMBaseTrainer::is_similar_to
   (const bob::learn::em::GMMBaseTrainer &other, const double r_epsilon,
    const double a_epsilon) const
 {
-  return m_ss == other.m_ss &&
+  return *m_ss == *other.m_ss &&
     m_update_means == other.m_update_means &&
     m_update_variances == other.m_update_variances &&
     m_update_weights == other.m_update_weights &&
@@ -87,8 +89,8 @@ bool bob::learn::em::GMMBaseTrainer::is_similar_to
     other.m_mean_var_update_responsibilities_threshold, r_epsilon, a_epsilon);
 }

-void bob::learn::em::GMMBaseTrainer::setGMMStats(const bob::learn::em::GMMStats& stats)
+void bob::learn::em::GMMBaseTrainer::setGMMStats(boost::shared_ptr<bob::learn::em::GMMStats> stats)
 {
-  bob::core::array::assertSameShape(m_ss.sumPx, stats.sumPx);
+  bob::core::array::assertSameShape(m_ss->sumPx, stats->sumPx);
   m_ss = stats;
 }
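Note the semantics of the new setter: `m_ss = stats;` now stores the caller's `boost::shared_ptr` directly, so trainer and caller share a single `GMMStats` instance rather than working on a copy (the getter in the header below hands out the same pointer). Through the bindings this should mean, roughly (a sketch continuing the names from the example above; the `GMMStats` constructor arguments are assumed):

```python
# Sketch of the shared-pointer aliasing; the shapes must match the machine,
# otherwise assertSameShape above raises.
external = bob.learn.em.GMMStats(2, 3)  # assumed: (n_gaussians, n_inputs)
trainer.gmm_statistics = external       # trainer now aliases this object
trainer.e_step(machine, data)           # accumulates into `external`
print(external.log_likelihood)          # reflects the trainer's E-step
```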
@@ -14,8 +14,8 @@ bob::learn::em::MAP_GMMTrainer::MAP_GMMTrainer(
   const bool update_weights,
   const double mean_var_update_responsibilities_threshold,
   const bool reynolds_adaptation,
   const double relevance_factor,
   const double alpha,
   boost::shared_ptr<bob::learn::em::GMMMachine> prior_gmm):
@@ -33,7 +33,7 @@ bob::learn::em::MAP_GMMTrainer::MAP_GMMTrainer(const bob::learn::em::MAP_GMMTrai
   m_prior_gmm(b.m_prior_gmm)
 {
   m_relevance_factor = b.m_relevance_factor;
   m_alpha = b.m_alpha;
   m_reynolds_adaptation = b.m_reynolds_adaptation;
 }
@@ -75,10 +75,10 @@ void bob::learn::em::MAP_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
 {
   // Read options and variables
   double n_gaussians = gmm.getNGaussians();

   //Checking if it is necessary to resize the cache
   if((size_t)m_cache_alpha.extent(0) != n_gaussians)
     initialize(gmm); //If it is different for some reason, there is no way, you have to initialize

   // Check that the prior GMM has been specified
   if (!m_prior_gmm)
@@ -92,13 +92,13 @@ void bob::learn::em::MAP_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
     if (!m_reynolds_adaptation)
       m_cache_alpha = m_alpha;
     else
-      m_cache_alpha = m_gmm_base_trainer.getGMMStats().n(i) / (m_gmm_base_trainer.getGMMStats().n(i) + m_relevance_factor);
+      m_cache_alpha = m_gmm_base_trainer.getGMMStats()->n(i) / (m_gmm_base_trainer.getGMMStats()->n(i) + m_relevance_factor);

   // - Update weights if requested
   // Equation 11 of Reynolds et al., "Speaker Verification Using Adapted Gaussian Mixture Models", Digital Signal Processing, 2000
   if (m_gmm_base_trainer.getUpdateWeights()) {
     // Calculate the maximum likelihood weights
-    m_cache_ml_weights = m_gmm_base_trainer.getGMMStats().n / static_cast<double>(m_gmm_base_trainer.getGMMStats().T); //cast req. for linux/32-bits & osx
+    m_cache_ml_weights = m_gmm_base_trainer.getGMMStats()->n / static_cast<double>(m_gmm_base_trainer.getGMMStats()->T); //cast req. for linux/32-bits & osx

     // Get the prior weights
     const blitz::Array<double,1>& prior_weights = m_prior_gmm->getWeights();
@@ -123,12 +123,12 @@ void bob::learn::em::MAP_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
     for (size_t i=0; i<n_gaussians; ++i) {
       const blitz::Array<double,1>& prior_means = m_prior_gmm->getGaussian(i)->getMean();
       blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean();
-      if (m_gmm_base_trainer.getGMMStats().n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) {
+      if (m_gmm_base_trainer.getGMMStats()->n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) {
         means = prior_means;
       }
       else {
         // Use the maximum likelihood means
-        means = m_cache_alpha(i) * (m_gmm_base_trainer.getGMMStats().sumPx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats().n(i)) + (1-m_cache_alpha(i)) * prior_means;
+        means = m_cache_alpha(i) * (m_gmm_base_trainer.getGMMStats()->sumPx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats()->n(i)) + (1-m_cache_alpha(i)) * prior_means;
       }
     }
   }
@@ -142,11 +142,11 @@ void bob::learn::em::MAP_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
       blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean();
       const blitz::Array<double,1>& prior_variances = m_prior_gmm->getGaussian(i)->getVariance();
       blitz::Array<double,1>& variances = gmm.getGaussian(i)->updateVariance();
-      if (m_gmm_base_trainer.getGMMStats().n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) {
+      if (m_gmm_base_trainer.getGMMStats()->n(i) < m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold()) {
         variances = (prior_variances + prior_means) - blitz::pow2(means);
       }
       else {
-        variances = m_cache_alpha(i) * m_gmm_base_trainer.getGMMStats().sumPxx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats().n(i) + (1-m_cache_alpha(i)) * (prior_variances + prior_means) - blitz::pow2(means);
+        variances = m_cache_alpha(i) * m_gmm_base_trainer.getGMMStats()->sumPxx(i,blitz::Range::all()) / m_gmm_base_trainer.getGMMStats()->n(i) + (1-m_cache_alpha(i)) * (prior_variances + prior_means) - blitz::pow2(means);
       }
       gmm.getGaussian(i)->applyVarianceThresholds();
     }
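For reference, the MAP updates implemented above follow the adaptation equations of Reynolds et al. (2000), cited in the code. Writing $n_i$, $\sum_t \gamma_i(t)\,x_t$ and $\sum_t \gamma_i(t)\,x_t^2$ for the sufficient statistics (`n`, `sumPx`, `sumPxx`), the standard formulation is:

```latex
% MAP adaptation, Reynolds et al. (2000), Eqs. 11-13, diagonal covariances:
\begin{align*}
  \alpha_i &= \frac{n_i}{n_i + r}
    && \text{(Reynolds adaptation; otherwise a fixed } \alpha\text{)}\\
  \hat{w}_i &= \gamma \left[ \alpha_i\, n_i / T + (1 - \alpha_i)\, w_i \right]\\
  \hat{\mu}_i &= \alpha_i \frac{\sum_t \gamma_i(t)\, x_t}{n_i} + (1 - \alpha_i)\, \mu_i\\
  \hat{\sigma}_i^2 &= \alpha_i \frac{\sum_t \gamma_i(t)\, x_t^2}{n_i}
    + (1 - \alpha_i)\left(\sigma_i^2 + \mu_i^2\right) - \hat{\mu}_i^2
\end{align*}
```

Here $r$ is the relevance factor, $T$ the number of samples, $\gamma$ a scale that renormalizes the weights, and $w_i$, $\mu_i$, $\sigma_i^2$ the prior-model parameters; when $n_i$ falls below the responsibilities threshold, the code falls back to the prior parameters instead.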
@@ -200,4 +200,3 @@ bool bob::learn::em::MAP_GMMTrainer::is_similar_to
     bob::core::isClose(m_alpha, other.m_alpha, r_epsilon, a_epsilon) &&
     m_reynolds_adaptation == other.m_reynolds_adaptation;
 }
@@ -10,7 +10,7 @@
 bob::learn::em::ML_GMMTrainer::ML_GMMTrainer(
   const bool update_means,
   const bool update_variances,
   const bool update_weights,
   const double mean_var_update_responsibilities_threshold
 ):
@@ -29,7 +29,7 @@ bob::learn::em::ML_GMMTrainer::~ML_GMMTrainer()
 void bob::learn::em::ML_GMMTrainer::initialize(bob::learn::em::GMMMachine& gmm)
 {
   m_gmm_base_trainer.initialize(gmm);

   // Allocate cache
   size_t n_gaussians = gmm.getNGaussians();
   m_cache_ss_n_thresholded.resize(n_gaussians);
@@ -49,14 +49,14 @@ void bob::learn::em::ML_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
   // Equation 9.26 of Bishop, "Pattern recognition and machine learning", 2006
   if (m_gmm_base_trainer.getUpdateWeights()) {
     blitz::Array<double,1>& weights = gmm.updateWeights();
-    weights = m_gmm_base_trainer.getGMMStats().n / static_cast<double>(m_gmm_base_trainer.getGMMStats().T); //cast req. for linux/32-bits & osx
+    weights = m_gmm_base_trainer.getGMMStats()->n / static_cast<double>(m_gmm_base_trainer.getGMMStats()->T); //cast req. for linux/32-bits & osx

     // Recompute the log weights in the cache of the GMMMachine
     gmm.recomputeLogWeights();
   }

   // Generate a thresholded version of m_ss.n
   for(size_t i=0; i<n_gaussians; ++i)
-    m_cache_ss_n_thresholded(i) = std::max(m_gmm_base_trainer.getGMMStats().n(i), m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold());
+    m_cache_ss_n_thresholded(i) = std::max(m_gmm_base_trainer.getGMMStats()->n(i), m_gmm_base_trainer.getMeanVarUpdateResponsibilitiesThreshold());

   // Update GMM parameters using the sufficient statistics (m_ss)
   // - Update means if requested
@@ -64,7 +64,7 @@ void bob::learn::em::ML_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
   if (m_gmm_base_trainer.getUpdateMeans()) {
     for(size_t i=0; i<n_gaussians; ++i) {
       blitz::Array<double,1>& means = gmm.getGaussian(i)->updateMean();
-      means = m_gmm_base_trainer.getGMMStats().sumPx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i);
+      means = m_gmm_base_trainer.getGMMStats()->sumPx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i);
     }
   }
@@ -77,7 +77,7 @@ void bob::learn::em::ML_GMMTrainer::mStep(bob::learn::em::GMMMachine& gmm)
     for(size_t i=0; i<n_gaussians; ++i) {
       const blitz::Array<double,1>& means = gmm.getGaussian(i)->getMean();
       blitz::Array<double,1>& variances = gmm.getGaussian(i)->updateVariance();
-      variances = m_gmm_base_trainer.getGMMStats().sumPxx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i) - blitz::pow2(means);
+      variances = m_gmm_base_trainer.getGMMStats()->sumPxx(i, blitz::Range::all()) / m_cache_ss_n_thresholded(i) - blitz::pow2(means);
       gmm.getGaussian(i)->applyVarianceThresholds();
     }
   }
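The ML counterpart above implements the standard EM re-estimation formulas referenced in the code (Bishop, "Pattern recognition and machine learning", 2006, Eqs. 9.24-9.26), with $n_i$ clamped from below by the responsibilities threshold:

```latex
% ML re-estimation, diagonal covariances; n, sumPx, sumPxx as in the code:
\begin{align*}
  w_i &= \frac{n_i}{T}, &
  \mu_i &= \frac{\sum_t \gamma_i(t)\, x_t}{n_i}, &
  \sigma_i^2 &= \frac{\sum_t \gamma_i(t)\, x_t^2}{n_i} - \mu_i^2
\end{align*}
```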
@@ -30,7 +30,7 @@ class GMMBaseTrainer
   * @brief Default constructor
   */
  GMMBaseTrainer(const bool update_means=true,
    const bool update_variances=false,
    const bool update_weights=false,
    const double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon());
@@ -93,21 +93,21 @@ class GMMBaseTrainer
   * @brief Returns the internal GMM statistics. Useful to parallelize the
   * E-step
   */
-  const bob::learn::em::GMMStats getGMMStats() const
+  const boost::shared_ptr<bob::learn::em::GMMStats> getGMMStats() const
  { return m_ss; }

  /**
   * @brief Sets the internal GMM statistics. Useful to parallelize the
   * E-step
   */
-  void setGMMStats(const bob::learn::em::GMMStats& stats);
+  void setGMMStats(boost::shared_ptr<bob::learn::em::GMMStats> stats);

  /**
   * update means on each iteration
   */
  bool getUpdateMeans()
  {return m_update_means;}

  /**
   * update variances on each iteration
   */
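The "parallelize the E-step" remark in the comments above is the motivation for this commit: workers can each run the E-step on a slice of the data and the partial statistics can then be merged before a single M-step. A hedged sketch, continuing the names from the first example; the in-place accumulation of `GMMStats` via `+=` is an assumption, it is not part of this diff:

```python
# Sketch: map-reduce style E-step; `+=` on GMMStats is an assumption.
chunks = numpy.array_split(data, 4)  # in practice: one chunk per worker

total = bob.learn.em.GMMStats(2, 3)
total.init()                         # zero the accumulators
for chunk in chunks:
    trainer.e_step(machine, chunk)   # fills trainer.gmm_statistics
    total += trainer.gmm_statistics  # merge the partial statistics

trainer.gmm_statistics = total       # hand the merged stats to the M-step
trainer.m_step(machine)
```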
@@ -117,19 +117,19 @@ class GMMBaseTrainer
  bool getUpdateWeights()
  {return m_update_weights;}

  double getMeanVarUpdateResponsibilitiesThreshold()
  {return m_mean_var_update_responsibilities_threshold;}

 private:
  /**
   * These are the sufficient statistics, calculated during the
   * E-step and used during the M-step
   */
-  bob::learn::em::GMMStats m_ss;
+  boost::shared_ptr<bob::learn::em::GMMStats> m_ss;

  /**
@@ -28,11 +28,11 @@ class MAP_GMMTrainer
   */
  MAP_GMMTrainer(
    const bool update_means=true,
    const bool update_variances=false,
    const bool update_weights=false,
    const double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon(),
    const bool reynolds_adaptation=false,
    const double relevance_factor=4,
    const double alpha=0.5,
    boost::shared_ptr<bob::learn::em::GMMMachine> prior_gmm = boost::shared_ptr<bob::learn::em::GMMMachine>());
@@ -108,14 +108,14 @@ class MAP_GMMTrainer
   */
  double computeLikelihood(bob::learn::em::GMMMachine& gmm){
    return m_gmm_base_trainer.computeLikelihood(gmm);
  }

  bool getReynoldsAdaptation()
  {return m_reynolds_adaptation;}

  void setReynoldsAdaptation(const bool reynolds_adaptation)
  {m_reynolds_adaptation = reynolds_adaptation;}

  double getRelevanceFactor()
  {return m_relevance_factor;}
@@ -130,6 +130,8 @@ class MAP_GMMTrainer
  void setAlpha(const double alpha)
  {m_alpha = alpha;}

+  bob::learn::em::GMMBaseTrainer& base_trainer(){return m_gmm_base_trainer;}

 protected:
@@ -140,7 +142,7 @@ class MAP_GMMTrainer
  /**
   Base Trainer for the MAP algorithm. Basically implements the e-step
   */
  bob::learn::em::GMMBaseTrainer m_gmm_base_trainer;

  /**
@@ -28,7 +28,7 @@ class ML_GMMTrainer{
   * @brief Default constructor
   */
  ML_GMMTrainer(const bool update_means=true,
    const bool update_variances=false,
    const bool update_weights=false,
    const double mean_var_update_responsibilities_threshold = std::numeric_limits<double>::epsilon());
@@ -97,13 +97,15 @@ class ML_GMMTrainer{
   */
  bool is_similar_to(const ML_GMMTrainer& b, const double r_epsilon=1e-5,
    const double a_epsilon=1e-8) const;

+  bob::learn::em::GMMBaseTrainer& base_trainer(){return m_gmm_base_trainer;}

 protected:
  /**
   Base Trainer for the MAP algorithm. Basically implements the e-step
   */
  bob::learn::em::GMMBaseTrainer m_gmm_base_trainer;
@@ -244,6 +244,30 @@ int PyBobLearnEMMAPGMMTrainer_setAlpha(PyBobLearnEMMAPGMMTrainerObject* self, Py
 }

+static auto gmm_statistics = bob::extension::VariableDoc(
+  "gmm_statistics",
+  ":py:class:`GMMStats`",
+  "The GMM statistics that were used internally in the E- and M-steps",
+  "Setting and getting the internal GMM statistics might be useful to parallelize the GMM training."
+);
+PyObject* PyBobLearnEMMAPGMMTrainer_get_gmm_statistics(PyBobLearnEMMAPGMMTrainerObject* self, void*){
+  BOB_TRY
+  PyBobLearnEMGMMStatsObject* stats = (PyBobLearnEMGMMStatsObject*)PyBobLearnEMGMMStats_Type.tp_alloc(&PyBobLearnEMGMMStats_Type, 0);
+  stats->cxx = self->cxx->base_trainer().getGMMStats();
+  return Py_BuildValue("N", stats);
+  BOB_CATCH_MEMBER("gmm_statistics could not be read", 0)
+}
+int PyBobLearnEMMAPGMMTrainer_set_gmm_statistics(PyBobLearnEMMAPGMMTrainerObject* self, PyObject* value, void*){
+  BOB_TRY
+  if (!PyBobLearnEMGMMStats_Check(value)){
+    PyErr_Format(PyExc_RuntimeError, "%s %s expects a GMMStats object", Py_TYPE(self)->tp_name, gmm_statistics.name());
+    return -1;
+  }
+  self->cxx->base_trainer().setGMMStats(reinterpret_cast<PyBobLearnEMGMMStatsObject*>(value)->cxx);
+  return 0;
+  BOB_CATCH_MEMBER("gmm_statistics could not be set", -1)
+}

 static PyGetSetDef PyBobLearnEMMAPGMMTrainer_getseters[] = {
   {
@@ -260,7 +284,13 @@ static PyGetSetDef PyBobLearnEMMAPGMMTrainer_getseters[] = {
    relevance_factor.doc(),
    0
  },
+  {
+    gmm_statistics.name(),
+    (getter)PyBobLearnEMMAPGMMTrainer_get_gmm_statistics,
+    (setter)PyBobLearnEMMAPGMMTrainer_set_gmm_statistics,
+    gmm_statistics.doc(),
+    0
+  },
  {0} // Sentinel
};
@@ -303,12 +333,11 @@ static PyObject* PyBobLearnEMMAPGMMTrainer_initialize(PyBobLearnEMMAPGMMTrainerO
 /*** e_step ***/
 static auto e_step = bob::extension::FunctionDoc(
   "e_step",
-  "Calculates and saves statistics across the dataset,"
-  "and saves these as m_ss. ",
+  "Calculates and saves statistics across the dataset and saves these as :py:attr:`gmm_statistics`. ",
   "Calculates the average log likelihood of the observations given the GMM,"
   "and returns this in average_log_likelihood."
-  "The statistics, m_ss, will be used in the m_step() that follows.",
+  "The statistics, :py:attr:`gmm_statistics`, will be used in the :py:meth:`m_step` that follows.",
   true
 )
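The docstring above describes the usual manual training loop; sketched below with an assumed iteration cap and convergence tolerance, and with the Python spelling `compute_likelihood` assumed for the C++ `computeLikelihood` shown earlier:

```python
# Sketch of the documented e_step/m_step protocol.
trainer.initialize(machine)
prev_ll = -numpy.inf
for _ in range(20):                           # max iterations (assumed)
    trainer.e_step(machine, data)             # fills gmm_statistics
    ll = trainer.compute_likelihood(machine)  # average log-likelihood
    trainer.m_step(machine)                   # update from gmm_statistics
    if abs(ll - prev_ll) < 1e-5:              # tolerance (assumed)
        break
    prev_ll = ll
```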
@@ -357,9 +386,7 @@ static PyObject* PyBobLearnEMMAPGMMTrainer_e_step(PyBobLearnEMMAPGMMTrainerObjec
 static auto m_step = bob::extension::FunctionDoc(
   "m_step",
-  "Performs a maximum a posteriori (MAP) update of the GMM:"
-  "* parameters using the accumulated statistics in :py:class:`bob.learn.em.GMMBaseTrainer.m_ss` and the"
-  "* parameters of the prior model",
+  "Performs a maximum a posteriori (MAP) update of the GMM parameters using the accumulated statistics in :py:attr:`gmm_statistics` and the parameters of the prior model",
   "",
   true
 )
@@ -144,7 +144,39 @@ static PyObject* PyBobLearnEMMLGMMTrainer_RichCompare(PyBobLearnEMMLGMMTrainerOb
 /************ Variables Section ***********************************/
 /******************************************************************/

+static auto gmm_statistics = bob::extension::VariableDoc(
+  "gmm_statistics",
+  ":py:class:`GMMStats`",
+  "The GMM statistics that were used internally in the E- and M-steps",
+  "Setting and getting the internal GMM statistics might be useful to parallelize the GMM training."
+);
+PyObject* PyBobLearnEMMLGMMTrainer_get_gmm_statistics(PyBobLearnEMMLGMMTrainerObject* self, void*){
+  BOB_TRY
+  PyBobLearnEMGMMStatsObject* stats = (PyBobLearnEMGMMStatsObject*)PyBobLearnEMGMMStats_Type.tp_alloc(&PyBobLearnEMGMMStats_Type, 0);
+  stats->cxx = self->cxx->base_trainer().getGMMStats();
+  return Py_BuildValue("N", stats);
+  BOB_CATCH_MEMBER("gmm_statistics could not be read", 0)
+}
+int PyBobLearnEMMLGMMTrainer_set_gmm_statistics(PyBobLearnEMMLGMMTrainerObject* self, PyObject* value, void*){
+  BOB_TRY
+  if (!PyBobLearnEMGMMStats_Check(value)){
+    PyErr_Format(PyExc_RuntimeError, "%s %s expects a GMMStats object", Py_TYPE(self)->tp_name, gmm_statistics.name());
+    return -1;
+  }
+  self->cxx->base_trainer().setGMMStats(reinterpret_cast<PyBobLearnEMGMMStatsObject*>(value)->cxx);
+  return 0;
+  BOB_CATCH_MEMBER("gmm_statistics could not be set", -1)
+}

 static PyGetSetDef PyBobLearnEMMLGMMTrainer_getseters[] = {
+  {
+    gmm_statistics.name(),
+    (getter)PyBobLearnEMMLGMMTrainer_get_gmm_statistics,
+    (setter)PyBobLearnEMMLGMMTrainer_set_gmm_statistics,
+    gmm_statistics.doc(),
+    0
+  },
  {0} // Sentinel
};
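Since the setter above stores the underlying shared pointer, statistics produced by one trainer instance can be handed to another, e.g. accumulated in worker processes and set on a master trainer before its M-step. A sketch; the HDF5 round-trip of `GMMStats` (`save()` and the `GMMStats(HDF5File)` constructor) is an assumption taken from bob's usual I/O idiom, not shown in this diff:

```python
# Sketch: shipping statistics between processes; the I/O calls are assumptions.
import bob.io.base

stats = worker_trainer.gmm_statistics
stats.save(bob.io.base.HDF5File("stats_worker0.hdf5", "w"))

loaded = bob.learn.em.GMMStats(bob.io.base.HDF5File("stats_worker0.hdf5"))
master_trainer.gmm_statistics = loaded   # shape must match the machine
master_trainer.m_step(machine)
```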
@@ -190,7 +222,7 @@ static auto e_step = bob::extension::FunctionDoc(
   "Calculates the average log likelihood of the observations given the GMM,"
   "and returns this in average_log_likelihood."
-  "The statistics, m_ss, will be used in the :py:func:`m_step` that follows.",
+  "The statistics, :py:attr:`gmm_statistics`, will be used in the :py:func:`m_step` that follows.",
   true
 )
static auto m_step = bob::extension::FunctionDoc(
"m_step",
"Performs a maximum likelihood (ML) update of the GMM parameters "
"using the accumulated statistics in :py:attr:`bob.learn.em.GMMBaseTrainer.m_ss`",
"using the accumulated statistics in :py:attr:`gmm_statistics`",
"See Section 9.2.2 of Bishop, \"Pattern recognition and machine learning\", 2006",