diff --git a/bob/learn/em/cpp/IVectorTrainer.cpp b/bob/learn/em/cpp/IVectorTrainer.cpp index 82ad2504a3d09881b807d4b98c1effc0566d460c..a42bb2ba7033f398e83de29caef03236334673eb 100644 --- a/bob/learn/em/cpp/IVectorTrainer.cpp +++ b/bob/learn/em/cpp/IVectorTrainer.cpp @@ -49,7 +49,17 @@ bob::learn::em::IVectorTrainer::~IVectorTrainer() void bob::learn::em::IVectorTrainer::initialize( bob::learn::em::IVectorMachine& machine) { + // Initializes \f$T\f$ and \f$\Sigma\f$ of the machine + blitz::Array<double,2>& T = machine.updateT(); + bob::core::array::randn(*m_rng, T); + blitz::Array<double,1>& sigma = machine.updateSigma(); + sigma = machine.getUbm()->getVarianceSupervector(); + machine.precompute(); +} +void bob::learn::em::IVectorTrainer::resetAccumulators(const bob::learn::em::IVectorMachine& machine) +{ + // Resize the accumulator const int C = machine.getNGaussians(); const int D = machine.getNInputs(); const int Rt = machine.getDimRt(); @@ -75,14 +85,17 @@ void bob::learn::em::IVectorTrainer::initialize( if (m_update_sigma) m_tmp_dd1.resize(D,D); - // Initializes \f$T\f$ and \f$\Sigma\f$ of the machine - blitz::Array<double,2>& T = machine.updateT(); - bob::core::array::randn(*m_rng, T); - blitz::Array<double,1>& sigma = machine.updateSigma(); - sigma = machine.getUbm()->getVarianceSupervector(); - machine.precompute(); + // initialize with 0 + m_acc_Nij_wij2 = 0.; + m_acc_Fnormij_wij = 0.; + if (m_update_sigma) + { + m_acc_Nij = 0.; + m_acc_Snormij = 0.; + } } + void bob::learn::em::IVectorTrainer::eStep( bob::learn::em::IVectorMachine& machine, const std::vector<bob::learn::em::GMMStats>& data) @@ -91,13 +104,8 @@ void bob::learn::em::IVectorTrainer::eStep( const int C = machine.getNGaussians(); // Reinitializes accumulators to 0 - m_acc_Nij_wij2 = 0.; - m_acc_Fnormij_wij = 0.; - if (m_update_sigma) - { - m_acc_Nij = 0.; - m_acc_Snormij = 0.; - } + resetAccumulators(machine); + for (std::vector<bob::learn::em::GMMStats>::const_iterator it = data.begin(); it != data.end(); ++it) { @@ -179,7 +187,7 @@ bob::learn::em::IVectorTrainer& bob::learn::em::IVectorTrainer::operator= (const bob::learn::em::IVectorTrainer &other) { if (this != &other) - { + { m_update_sigma = other.m_update_sigma; m_acc_Nij_wij2.reference(bob::core::array::ccopy(other.m_acc_Nij_wij2)); @@ -225,4 +233,3 @@ bool bob::learn::em::IVectorTrainer::is_similar_to bob::core::array::isClose(m_acc_Nij, other.m_acc_Nij, r_epsilon, a_epsilon) && bob::core::array::isClose(m_acc_Snormij, other.m_acc_Snormij, r_epsilon, a_epsilon); } - diff --git a/bob/learn/em/include/bob.learn.em/IVectorTrainer.h b/bob/learn/em/include/bob.learn.em/IVectorTrainer.h index 4f18226bafaf41ea47b765d099a85cd1c55f8e49..0482f1513116289ecf1242bfa20d8e4856d4c7db 100644 --- a/bob/learn/em/include/bob.learn.em/IVectorTrainer.h +++ b/bob/learn/em/include/bob.learn.em/IVectorTrainer.h @@ -51,6 +51,12 @@ class IVectorTrainer */ virtual void initialize(bob::learn::em::IVectorMachine& ivector); + /** + * @brief Reset the statistics accumulators + * to the correct size and a value of zero. + */ + void resetAccumulators(const bob::learn::em::IVectorMachine& ivector); + /** * @brief Calculates statistics across the dataset, * and saves these as: