Skip to content
Snippets Groups Projects
Commit 65ddebe8 authored by Manuel Günther's avatar Manuel Günther
Browse files

Separated IVector machine initialization from IVector trainer initialization

parent f2534aa4
No related branches found
No related tags found
No related merge requests found
...@@ -49,7 +49,17 @@ bob::learn::em::IVectorTrainer::~IVectorTrainer() ...@@ -49,7 +49,17 @@ bob::learn::em::IVectorTrainer::~IVectorTrainer()
void bob::learn::em::IVectorTrainer::initialize( void bob::learn::em::IVectorTrainer::initialize(
bob::learn::em::IVectorMachine& machine) bob::learn::em::IVectorMachine& machine)
{ {
// Initializes \f$T\f$ and \f$\Sigma\f$ of the machine
blitz::Array<double,2>& T = machine.updateT();
bob::core::array::randn(*m_rng, T);
blitz::Array<double,1>& sigma = machine.updateSigma();
sigma = machine.getUbm()->getVarianceSupervector();
machine.precompute();
}
void bob::learn::em::IVectorTrainer::resetAccumulators(const bob::learn::em::IVectorMachine& machine)
{
// Resize the accumulator
const int C = machine.getNGaussians(); const int C = machine.getNGaussians();
const int D = machine.getNInputs(); const int D = machine.getNInputs();
const int Rt = machine.getDimRt(); const int Rt = machine.getDimRt();
...@@ -75,14 +85,17 @@ void bob::learn::em::IVectorTrainer::initialize( ...@@ -75,14 +85,17 @@ void bob::learn::em::IVectorTrainer::initialize(
if (m_update_sigma) if (m_update_sigma)
m_tmp_dd1.resize(D,D); m_tmp_dd1.resize(D,D);
// Initializes \f$T\f$ and \f$\Sigma\f$ of the machine // initialize with 0
blitz::Array<double,2>& T = machine.updateT(); m_acc_Nij_wij2 = 0.;
bob::core::array::randn(*m_rng, T); m_acc_Fnormij_wij = 0.;
blitz::Array<double,1>& sigma = machine.updateSigma(); if (m_update_sigma)
sigma = machine.getUbm()->getVarianceSupervector(); {
machine.precompute(); m_acc_Nij = 0.;
m_acc_Snormij = 0.;
}
} }
void bob::learn::em::IVectorTrainer::eStep( void bob::learn::em::IVectorTrainer::eStep(
bob::learn::em::IVectorMachine& machine, bob::learn::em::IVectorMachine& machine,
const std::vector<bob::learn::em::GMMStats>& data) const std::vector<bob::learn::em::GMMStats>& data)
...@@ -91,13 +104,8 @@ void bob::learn::em::IVectorTrainer::eStep( ...@@ -91,13 +104,8 @@ void bob::learn::em::IVectorTrainer::eStep(
const int C = machine.getNGaussians(); const int C = machine.getNGaussians();
// Reinitializes accumulators to 0 // Reinitializes accumulators to 0
m_acc_Nij_wij2 = 0.; resetAccumulators(machine);
m_acc_Fnormij_wij = 0.;
if (m_update_sigma)
{
m_acc_Nij = 0.;
m_acc_Snormij = 0.;
}
for (std::vector<bob::learn::em::GMMStats>::const_iterator it = data.begin(); for (std::vector<bob::learn::em::GMMStats>::const_iterator it = data.begin();
it != data.end(); ++it) it != data.end(); ++it)
{ {
...@@ -179,7 +187,7 @@ bob::learn::em::IVectorTrainer& bob::learn::em::IVectorTrainer::operator= ...@@ -179,7 +187,7 @@ bob::learn::em::IVectorTrainer& bob::learn::em::IVectorTrainer::operator=
(const bob::learn::em::IVectorTrainer &other) (const bob::learn::em::IVectorTrainer &other)
{ {
if (this != &other) if (this != &other)
{ {
m_update_sigma = other.m_update_sigma; m_update_sigma = other.m_update_sigma;
m_acc_Nij_wij2.reference(bob::core::array::ccopy(other.m_acc_Nij_wij2)); m_acc_Nij_wij2.reference(bob::core::array::ccopy(other.m_acc_Nij_wij2));
...@@ -225,4 +233,3 @@ bool bob::learn::em::IVectorTrainer::is_similar_to ...@@ -225,4 +233,3 @@ bool bob::learn::em::IVectorTrainer::is_similar_to
bob::core::array::isClose(m_acc_Nij, other.m_acc_Nij, r_epsilon, a_epsilon) && bob::core::array::isClose(m_acc_Nij, other.m_acc_Nij, r_epsilon, a_epsilon) &&
bob::core::array::isClose(m_acc_Snormij, other.m_acc_Snormij, r_epsilon, a_epsilon); bob::core::array::isClose(m_acc_Snormij, other.m_acc_Snormij, r_epsilon, a_epsilon);
} }
...@@ -51,6 +51,12 @@ class IVectorTrainer ...@@ -51,6 +51,12 @@ class IVectorTrainer
*/ */
virtual void initialize(bob::learn::em::IVectorMachine& ivector); virtual void initialize(bob::learn::em::IVectorMachine& ivector);
/**
* @brief Reset the statistics accumulators
* to the correct size and a value of zero.
*/
void resetAccumulators(const bob::learn::em::IVectorMachine& ivector);
/** /**
* @brief Calculates statistics across the dataset, * @brief Calculates statistics across the dataset,
* and saves these as: * and saves these as:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment