Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.em
Commits
001a8a53
Commit
001a8a53
authored
Feb 03, 2015
by
Tiago de Freitas Pereira
Browse files
Binded IVector Trainer
parent
ae73d4cd
Changes
4
Hide whitespace changes
Inline
Side-by-side
bob/learn/misc/cpp/IVectorMachine.cpp
View file @
001a8a53
...
...
@@ -229,6 +229,7 @@ void bob::learn::misc::IVectorMachine::computeTtSigmaInvFnorm(
m_tmp_d
=
gs
.
sumPx
(
c
,
rall
)
-
gs
.
n
(
c
)
*
m_ubm
->
getGaussian
(
c
)
->
getMean
();
blitz
::
Array
<
double
,
2
>
Tct_sigmacInv
=
m_cache_Tct_sigmacInv
(
c
,
rall
,
rall
);
bob
::
math
::
prod
(
Tct_sigmacInv
,
m_tmp_d
,
m_tmp_t2
);
output
+=
m_tmp_t2
;
}
}
...
...
bob/learn/misc/cpp/IVectorTrainer.cpp
View file @
001a8a53
...
...
@@ -49,6 +49,7 @@ bob::learn::misc::IVectorTrainer::~IVectorTrainer()
void
bob
::
learn
::
misc
::
IVectorTrainer
::
initialize
(
bob
::
learn
::
misc
::
IVectorMachine
&
machine
)
{
const
int
C
=
machine
.
getNGaussians
();
const
int
D
=
machine
.
getNInputs
();
const
int
Rt
=
machine
.
getDimRt
();
...
...
@@ -67,6 +68,7 @@ void bob::learn::misc::IVectorTrainer::initialize(
m_tmp_wij2
.
resize
(
Rt
,
Rt
);
m_tmp_d1
.
resize
(
D
);
m_tmp_t1
.
resize
(
Rt
);
m_tmp_dt1
.
resize
(
D
,
Rt
);
m_tmp_tt1
.
resize
(
Rt
,
Rt
);
m_tmp_tt2
.
resize
(
Rt
,
Rt
);
...
...
@@ -105,6 +107,7 @@ void bob::learn::misc::IVectorTrainer::eStep(
// b. Computes \f$Id + T^{T} \Sigma^{-1} T\f$
machine
.
computeIdTtSigmaInvT
(
*
it
,
m_tmp_tt1
);
// c. Computes \f$(Id + T^{T} \Sigma^{-1} T)^{-1}\f$
bob
::
math
::
inv
(
m_tmp_tt1
,
m_tmp_tt2
);
// d. Computes \f$E{wij} = (Id + T^{T} \Sigma^{-1} T)^{-1} T^{T} \Sigma^{-1} F_{norm}\f$
bob
::
math
::
prod
(
m_tmp_tt2
,
m_tmp_t1
,
m_tmp_wij
);
// E{wij}
...
...
bob/learn/misc/ivector_trainer.cpp
View file @
001a8a53
...
...
@@ -26,9 +26,8 @@ static int extract_GMMStats_1d(PyObject *list,
PyErr_Format
(
PyExc_RuntimeError
,
"Expected GMMStats objects"
);
return
-
1
;
}
bob
::
learn
::
misc
::
GMMStats
*
stats_pointer
=
stats
->
cxx
.
get
();
std
::
cout
<<
" #### "
<<
std
::
endl
;
training_data
.
push_back
(
*
(
stats_pointer
));
training_data
.
push_back
(
*
stats
->
cxx
);
}
return
0
;
}
...
...
@@ -360,9 +359,8 @@ static PyObject* PyBobLearnMiscIVectorTrainer_e_step(PyBobLearnMiscIVectorTraine
if
(
extract_GMMStats_1d
(
stats
,
training_data
)
==
0
)
self
->
cxx
->
eStep
(
*
ivector_machine
->
cxx
,
training_data
);
BOB_CATCH_MEMBER
(
"cannot perform the e_step method"
,
0
)
Py_RETURN_NONE
;
BOB_CATCH_MEMBER
(
"cannot perform the e_step method"
,
0
)
}
...
...
bob/learn/misc/test_ivector_trainer.py
View file @
001a8a53
...
...
@@ -11,7 +11,7 @@ import numpy
import
numpy.linalg
import
numpy.random
from
.
import
GMMMachine
,
GMMStats
,
IVectorMachine
,
IVectorTrainer
from
bob.learn.misc
import
GMMMachine
,
GMMStats
,
IVectorMachine
,
IVectorTrainer
### Test class inspired by an implementation of Chris McCool
### Chris McCool (chris.mccool@nicta.com.au)
...
...
@@ -229,7 +229,7 @@ def test_trainer_nosigma():
# Initialization
trainer
=
IVectorTrainer
()
trainer
.
initialize
(
m
,
data
)
trainer
.
initialize
(
m
)
m
.
t
=
t
m
.
sigma
=
sigma
for
it
in
range
(
2
):
...
...
@@ -241,7 +241,7 @@ def test_trainer_nosigma():
assert
numpy
.
allclose
(
acc_Fnorm_Sigma_wij_ref
[
it
][
k
],
trainer
.
acc_fnormij_wij
[
k
],
1e-5
)
# M-Step
trainer
.
m_step
(
m
,
data
)
trainer
.
m_step
(
m
)
assert
numpy
.
allclose
(
t_ref
[
it
],
m
.
t
,
1e-5
)
def
test_trainer_update_sigma
():
...
...
@@ -343,7 +343,7 @@ def test_trainer_update_sigma():
# Initialization
trainer
=
IVectorTrainer
(
update_sigma
=
True
)
trainer
.
initialize
(
m
,
data
)
trainer
.
initialize
(
m
)
m
.
t
=
t
m
.
sigma
=
sigma
for
it
in
range
(
2
):
...
...
@@ -357,7 +357,7 @@ def test_trainer_update_sigma():
assert
numpy
.
allclose
(
N_ref
[
it
],
trainer
.
acc_nij
,
1e-5
)
# M-Step
trainer
.
m_step
(
m
,
data
)
trainer
.
m_step
(
m
)
assert
numpy
.
allclose
(
t_ref
[
it
],
m
.
t
,
1e-5
)
assert
numpy
.
allclose
(
sigma_ref
[
it
],
m
.
sigma
,
1e-5
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment