Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.em
Commits
c7350c6b
Commit
c7350c6b
authored
Sep 04, 2017
by
Amir MOHAMMADI
Browse files
Don't pass rng to ML and MAP trainers
parent
262d7b62
Pipeline
#11931
passed with stages
in 16 minutes and 5 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
bob/learn/em/test/test_em.py
View file @
c7350c6b
...
...
@@ -55,6 +55,12 @@ def test_gmm_ML_1():
ar
=
bob
.
io
.
base
.
load
(
datafile
(
"faithful.torch3_f64.hdf5"
,
__name__
,
path
=
"../data/"
))
gmm
=
loadGMM
()
# test rng handling
ml_gmmtrainer
=
ML_GMMTrainer
(
True
,
True
,
True
)
rng
=
bob
.
core
.
random
.
mt19937
(
12345
)
bob
.
learn
.
em
.
train
(
ml_gmmtrainer
,
gmm
,
ar
,
convergence_threshold
=
0.001
,
rng
=
rng
)
gmm
=
loadGMM
()
ml_gmmtrainer
=
ML_GMMTrainer
(
True
,
True
,
True
)
#ml_gmmtrainer.train(gmm, ar)
bob
.
learn
.
em
.
train
(
ml_gmmtrainer
,
gmm
,
ar
,
convergence_threshold
=
0.001
)
...
...
@@ -114,6 +120,13 @@ def test_gmm_MAP_1():
ar
=
bob
.
io
.
base
.
load
(
datafile
(
'faithful.torch3_f64.hdf5'
,
__name__
,
path
=
"../data/"
))
# test with rng
rng
=
bob
.
core
.
random
.
mt19937
(
12345
)
gmm
=
GMMMachine
(
bob
.
io
.
base
.
HDF5File
(
datafile
(
"gmm_ML.hdf5"
,
__name__
,
path
=
"../data/"
)))
gmmprior
=
GMMMachine
(
bob
.
io
.
base
.
HDF5File
(
datafile
(
"gmm_ML.hdf5"
,
__name__
,
path
=
"../data/"
)))
map_gmmtrainer
=
MAP_GMMTrainer
(
update_means
=
True
,
update_variances
=
False
,
update_weights
=
False
,
prior_gmm
=
gmmprior
,
relevance_factor
=
4.
)
bob
.
learn
.
em
.
train
(
map_gmmtrainer
,
gmm
,
ar
,
rng
=
rng
)
gmm
=
GMMMachine
(
bob
.
io
.
base
.
HDF5File
(
datafile
(
"gmm_ML.hdf5"
,
__name__
,
path
=
"../data/"
)))
gmmprior
=
GMMMachine
(
bob
.
io
.
base
.
HDF5File
(
datafile
(
"gmm_ML.hdf5"
,
__name__
,
path
=
"../data/"
)))
...
...
@@ -253,9 +266,9 @@ def test_custom_trainer():
for
i
in
range
(
0
,
2
):
assert
(
ar
[
i
+
1
]
==
machine
.
means
[
i
,
:]).
all
()
def
test_EMPCA
():
# Tests our Probabilistic PCA trainer for linear machines for a simple
...
...
@@ -294,5 +307,5 @@ def test_EMPCA():
T
.
e_step
(
m
,
ar
)
T
.
m_step
(
m
,
ar
)
llh2
=
T
.
compute_likelihood
(
m
)
assert
abs
(
exp_llh2
-
llh2
)
<
2e-4
assert
abs
(
exp_llh2
-
llh2
)
<
2e-4
bob/learn/em/train.py
View file @
c7350c6b
...
...
@@ -45,7 +45,9 @@ def train(trainer, machine, data, max_iterations=50, convergence_threshold=None,
# Initialization
if
initialize
:
if
rng
is
not
None
:
if
rng
is
not
None
and
\
(
not
isinstance
(
trainer
,
(
bob
.
learn
.
em
.
ML_GMMTrainer
,
bob
.
learn
.
em
.
MAP_GMMTrainer
))):
trainer
.
initialize
(
machine
,
data
,
rng
)
else
:
trainer
.
initialize
(
machine
,
data
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment