Commit 001a8a53 authored by Tiago de Freitas Pereira

Binded IVector Trainer

parent ae73d4cd
@@ -229,6 +229,7 @@ void bob::learn::misc::IVectorMachine::computeTtSigmaInvFnorm(
m_tmp_d = gs.sumPx(c,rall) - gs.n(c) * m_ubm->getGaussian(c)->getMean();
blitz::Array<double,2> Tct_sigmacInv = m_cache_Tct_sigmacInv(c, rall, rall);
bob::math::prod(Tct_sigmacInv, m_tmp_d, m_tmp_t2);
output += m_tmp_t2;
}
}
......
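The hunk above extends `computeTtSigmaInvFnorm`, which accumulates T_c^T Sigma_c^{-1} (F_c - N_c mu_c) over the C Gaussians of the UBM into `output`. A minimal NumPy sketch of that accumulation, using hypothetical array shapes rather than the library's actual types:

```python
import numpy as np

def tt_sigma_inv_fnorm(T, sigma, means, n, sum_px):
    """Accumulate T^T Sigma^{-1} F_norm over the C Gaussians.

    Hypothetical shapes: T (C*D, Rt) stacked per Gaussian, sigma (C, D)
    diagonal covariances, means (C, D), n (C,) zeroth-order statistics,
    sum_px (C, D) first-order statistics.
    """
    C, D = means.shape
    output = np.zeros(T.shape[1])
    for c in range(C):
        fnorm_c = sum_px[c] - n[c] * means[c]     # F_c - N_c * mu_c
        T_c = T[c * D:(c + 1) * D, :]             # (D, Rt) block for Gaussian c
        output += T_c.T @ (fnorm_c / sigma[c])    # T_c^T Sigma_c^{-1} F_norm,c
    return output
```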
@@ -49,6 +49,7 @@ bob::learn::misc::IVectorTrainer::~IVectorTrainer()
void bob::learn::misc::IVectorTrainer::initialize(
bob::learn::misc::IVectorMachine& machine)
{
const int C = machine.getNGaussians();
const int D = machine.getNInputs();
const int Rt = machine.getDimRt();
@@ -67,6 +68,7 @@ void bob::learn::misc::IVectorTrainer::initialize(
m_tmp_wij2.resize(Rt,Rt);
m_tmp_d1.resize(D);
m_tmp_t1.resize(Rt);
m_tmp_dt1.resize(D,Rt);
m_tmp_tt1.resize(Rt,Rt);
m_tmp_tt2.resize(Rt,Rt);
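`initialize` only queries the model dimensions, C Gaussians, D-dimensional features and an Rt-dimensional total-variability subspace, and preallocates working buffers; the new line adds the D×Rt scratch matrix `m_tmp_dt1`. A rough NumPy equivalent of the shapes being resized (names mirror the C++ members; purposes beyond the shapes are only indicative):

```python
import numpy as np

def initialize_buffers(C, D, Rt):
    """Preallocate trainer scratch space from the machine dimensions."""
    # C (number of Gaussians) sizes the per-Gaussian accumulators, elided here.
    return {
        "tmp_wij2": np.zeros((Rt, Rt)),   # per-sample Rt x Rt scratch
        "tmp_d1":   np.zeros(D),          # scratch vector in feature space
        "tmp_t1":   np.zeros(Rt),         # holds T^T Sigma^{-1} F_norm in the E-step
        "tmp_dt1":  np.zeros((D, Rt)),    # D x Rt scratch added in this commit
        "tmp_tt1":  np.zeros((Rt, Rt)),   # holds Id + T^T Sigma^{-1} T
        "tmp_tt2":  np.zeros((Rt, Rt)),   # holds its inverse
    }
```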
@@ -105,6 +107,7 @@ void bob::learn::misc::IVectorTrainer::eStep(
// b. Computes \f$Id + T^{T} \Sigma^{-1} T\f$
machine.computeIdTtSigmaInvT(*it, m_tmp_tt1);
// c. Computes \f$(Id + T^{T} \Sigma^{-1} T)^{-1}\f$
bob::math::inv(m_tmp_tt1, m_tmp_tt2);
// d. Computes \f$E{wij} = (Id + T^{T} \Sigma^{-1} T)^{-1} T^{T} \Sigma^{-1} F_{norm}\f$
bob::math::prod(m_tmp_tt2, m_tmp_t1, m_tmp_wij); // E{wij}
......
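Steps b–d above are the core of the E-step: for each `GMMStats` sample, the posterior mean of the latent i-vector, E{w_ij} = (Id + T^T Sigma^{-1} T)^{-1} T^T Sigma^{-1} F_norm, is computed. A hedged NumPy sketch of that computation (hypothetical shapes; note that the per-Gaussian precision terms are weighted by the zeroth-order statistics N_c, which is why `computeIdTtSigmaInvT` takes the stats as argument):

```python
import numpy as np

def e_step_posterior(T, sigma, n, fnorm):
    """Posterior mean E{w_ij} of the latent i-vector for one sample.

    Hypothetical shapes: T (C, D, Rt), sigma (C, D) diagonal covariances,
    n (C,) zeroth-order stats, fnorm (C, D) with fnorm[c] = F_c - N_c * mu_c.
    """
    C, D, Rt = T.shape
    A = np.eye(Rt)                       # b. Id + T^T Sigma^{-1} T (N_c-weighted)
    b = np.zeros(Rt)
    for c in range(C):
        Tc_siginv = T[c].T / sigma[c]    # T_c^T Sigma_c^{-1}, shape (Rt, D)
        A += n[c] * (Tc_siginv @ T[c])
        b += Tc_siginv @ fnorm[c]        # T^T Sigma^{-1} F_norm
    # c. + d. invert and apply: E{w_ij} = A^{-1} b
    return np.linalg.solve(A, b)
```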
@@ -26,9 +26,8 @@ static int extract_GMMStats_1d(PyObject *list,
PyErr_Format(PyExc_RuntimeError, "Expected GMMStats objects");
return -1;
}
bob::learn::misc::GMMStats *stats_pointer = stats->cxx.get();
std::cout << " #### " << std::endl;
training_data.push_back(*(stats_pointer));
training_data.push_back(*stats->cxx);
}
return 0;
}
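`extract_GMMStats_1d` walks a Python list and copies each wrapped `GMMStats` object (`*stats->cxx`) into the `std::vector` the C++ trainer works on, so from Python the training data is just a list of `GMMStats`. A hedged sketch of building such a list (the `GMMStats(n_gaussians, n_inputs)` constructor and the `n`/`sum_px` attributes are assumed from the library's usual API):

```python
from bob.learn.misc import GMMStats

# Hypothetical dimensions: 2 Gaussians, 3-dimensional features
stats_list = []
for _ in range(4):            # four training samples
    s = GMMStats(2, 3)        # assumed constructor: (n_gaussians, n_inputs)
    # ... fill s.n (zeroth-order) and s.sum_px (first-order) from data ...
    stats_list.append(s)
# The binding copies each element into a std::vector<bob::learn::misc::GMMStats>.
```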
@@ -360,9 +359,8 @@ static PyObject* PyBobLearnMiscIVectorTrainer_e_step(PyBobLearnMiscIVectorTraine
if(extract_GMMStats_1d(stats ,training_data)==0)
self->cxx->eStep(*ivector_machine->cxx, training_data);
BOB_CATCH_MEMBER("cannot perform the e_step method", 0)
Py_RETURN_NONE;
BOB_CATCH_MEMBER("cannot perform the e_step method", 0)
}
......
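On the Python side the bound `e_step` takes the `IVectorMachine` plus that list of `GMMStats` and returns `None`; the E-step results are read back from the trainer's accumulators. A minimal usage sketch (constructor signatures are assumed, not taken from this diff):

```python
from bob.learn.misc import GMMMachine, IVectorMachine, IVectorTrainer

ubm = GMMMachine(2, 3)               # assumed: (n_gaussians, n_inputs)
machine = IVectorMachine(ubm, 2)     # assumed: (ubm, Rt) with Rt = 2
trainer = IVectorTrainer()
trainer.initialize(machine)          # new signature: no data argument
trainer.e_step(machine, stats_list)  # stats_list as in the sketch above
# e_step returns None; accumulators such as trainer.acc_fnormij_wij and
# trainer.acc_nij (used by the tests below) are updated in place.
```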
@@ -11,7 +11,7 @@ import numpy
import numpy.linalg
import numpy.random
from . import GMMMachine, GMMStats, IVectorMachine, IVectorTrainer
from bob.learn.misc import GMMMachine, GMMStats, IVectorMachine, IVectorTrainer
### Test class inspired by an implementation of Chris McCool
### Chris McCool (chris.mccool@nicta.com.au)
@@ -229,7 +229,7 @@ def test_trainer_nosigma():
# Initialization
trainer = IVectorTrainer()
trainer.initialize(m, data)
trainer.initialize(m)
m.t = t
m.sigma = sigma
for it in range(2):
@@ -241,7 +241,7 @@ def test_trainer_nosigma():
assert numpy.allclose(acc_Fnorm_Sigma_wij_ref[it][k], trainer.acc_fnormij_wij[k], 1e-5)
# M-Step
trainer.m_step(m, data)
trainer.m_step(m)
assert numpy.allclose(t_ref[it], m.t, 1e-5)
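The test now follows the new trainer API: `initialize` and `m_step` take only the machine, while the `GMMStats` list is passed to `e_step`. A condensed sketch of the resulting loop (`m`, `data`, `t`, `sigma` and the reference arrays come from the test setup, which is elided here):

```python
from bob.learn.misc import IVectorTrainer

trainer = IVectorTrainer()
trainer.initialize(m)           # was: trainer.initialize(m, data)
m.t = t
m.sigma = sigma
for it in range(2):
    trainer.e_step(m, data)     # accumulate statistics over the GMMStats list
    trainer.m_step(m)           # was: trainer.m_step(m, data); re-estimates m.t
```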
def test_trainer_update_sigma():
@@ -343,7 +343,7 @@ def test_trainer_update_sigma():
# Initialization
trainer = IVectorTrainer(update_sigma=True)
trainer.initialize(m, data)
trainer.initialize(m)
m.t = t
m.sigma = sigma
for it in range(2):
@@ -357,7 +357,7 @@ def test_trainer_update_sigma():
assert numpy.allclose(N_ref[it], trainer.acc_nij, 1e-5)
# M-Step
trainer.m_step(m, data)
trainer.m_step(m)
assert numpy.allclose(t_ref[it], m.t, 1e-5)
assert numpy.allclose(sigma_ref[it], m.sigma, 1e-5)
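`test_trainer_update_sigma` runs the same loop with `update_sigma=True`, in which case the M-step also re-estimates the per-Gaussian variances, so both `m.t` and `m.sigma` are checked against the references. Sketch of the only difference from the loop above:

```python
trainer = IVectorTrainer(update_sigma=True)
trainer.initialize(m)
for it in range(2):
    trainer.e_step(m, data)
    trainer.m_step(m)           # now updates both m.t and m.sigma
```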