Commit 65ddebe8 authored by Manuel Günther's avatar Manuel Günther
Browse files

Separated IVector machine initialization from IVector trainer initialization

parent f2534aa4
......@@ -49,7 +49,17 @@ bob::learn::em::IVectorTrainer::~IVectorTrainer()
void bob::learn::em::IVectorTrainer::initialize(
bob::learn::em::IVectorMachine& machine)
{
// Initializes \f$T\f$ and \f$\Sigma\f$ of the machine
blitz::Array<double,2>& T = machine.updateT();
bob::core::array::randn(*m_rng, T);
blitz::Array<double,1>& sigma = machine.updateSigma();
sigma = machine.getUbm()->getVarianceSupervector();
machine.precompute();
}
void bob::learn::em::IVectorTrainer::resetAccumulators(const bob::learn::em::IVectorMachine& machine)
{
// Resize the accumulator
const int C = machine.getNGaussians();
const int D = machine.getNInputs();
const int Rt = machine.getDimRt();
......@@ -75,14 +85,17 @@ void bob::learn::em::IVectorTrainer::initialize(
if (m_update_sigma)
m_tmp_dd1.resize(D,D);
// Initializes \f$T\f$ and \f$\Sigma\f$ of the machine
blitz::Array<double,2>& T = machine.updateT();
bob::core::array::randn(*m_rng, T);
blitz::Array<double,1>& sigma = machine.updateSigma();
sigma = machine.getUbm()->getVarianceSupervector();
machine.precompute();
// initialize with 0
m_acc_Nij_wij2 = 0.;
m_acc_Fnormij_wij = 0.;
if (m_update_sigma)
{
m_acc_Nij = 0.;
m_acc_Snormij = 0.;
}
}
void bob::learn::em::IVectorTrainer::eStep(
bob::learn::em::IVectorMachine& machine,
const std::vector<bob::learn::em::GMMStats>& data)
......@@ -91,13 +104,8 @@ void bob::learn::em::IVectorTrainer::eStep(
const int C = machine.getNGaussians();
// Reinitializes accumulators to 0
m_acc_Nij_wij2 = 0.;
m_acc_Fnormij_wij = 0.;
if (m_update_sigma)
{
m_acc_Nij = 0.;
m_acc_Snormij = 0.;
}
resetAccumulators(machine);
for (std::vector<bob::learn::em::GMMStats>::const_iterator it = data.begin();
it != data.end(); ++it)
{
......@@ -179,7 +187,7 @@ bob::learn::em::IVectorTrainer& bob::learn::em::IVectorTrainer::operator=
(const bob::learn::em::IVectorTrainer &other)
{
if (this != &other)
{
{
m_update_sigma = other.m_update_sigma;
m_acc_Nij_wij2.reference(bob::core::array::ccopy(other.m_acc_Nij_wij2));
......@@ -225,4 +233,3 @@ bool bob::learn::em::IVectorTrainer::is_similar_to
bob::core::array::isClose(m_acc_Nij, other.m_acc_Nij, r_epsilon, a_epsilon) &&
bob::core::array::isClose(m_acc_Snormij, other.m_acc_Snormij, r_epsilon, a_epsilon);
}
......@@ -51,6 +51,12 @@ class IVectorTrainer
*/
virtual void initialize(bob::learn::em::IVectorMachine& ivector);
/**
* @brief Reset the statistics accumulators
* to the correct size and a value of zero.
*/
void resetAccumulators(const bob::learn::em::IVectorMachine& ivector);
/**
* @brief Calculates statistics across the dataset,
* and saves these as:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment