Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.em
Commits
0abaac5f
Commit
0abaac5f
authored
Aug 02, 2017
by
Manuel Günther
Browse files
Made EM and JFA training less verbose
parent
4e704a43
Pipeline
#11611
failed with stages
in 44 minutes and 21 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
bob/learn/em/train.py
View file @
0abaac5f
...
...
@@ -58,7 +58,7 @@ def train(trainer, machine, data, max_iterations=50, convergence_threshold=None,
average_output
=
trainer
.
compute_likelihood
(
machine
)
for
i
in
range
(
max_iterations
):
logger
.
info
(
"Iteration = %d/%d"
,
i
,
max_iterations
)
logger
.
debug
(
"Iteration = %d/%d"
,
i
+
1
,
max_iterations
)
average_output_previous
=
average_output
trainer
.
m_step
(
machine
,
data
)
trainer
.
e_step
(
machine
,
data
)
...
...
@@ -67,15 +67,16 @@ def train(trainer, machine, data, max_iterations=50, convergence_threshold=None,
average_output
=
trainer
.
compute_likelihood
(
machine
)
if
type
(
machine
)
is
bob
.
learn
.
em
.
KMeansMachine
:
logger
.
info
(
"average euclidean distance = %f"
,
average_output
)
logger
.
debug
(
"average euclidean distance = %f"
,
average_output
)
else
:
logger
.
info
(
"log likelihood = %f"
,
average_output
)
logger
.
debug
(
"log likelihood = %f"
,
average_output
)
convergence_value
=
abs
((
average_output_previous
-
average_output
)
/
average_output_previous
)
logger
.
info
(
"convergence value = %f"
,
convergence_value
)
logger
.
debug
(
"convergence value = %f"
,
convergence_value
)
# Terminates if converged (and likelihood computation is set)
if
convergence_threshold
!=
None
and
convergence_value
<=
convergence_threshold
:
logger
.
info
(
"EM training converged after %d iterations with convergence value %f"
,
convergence_value
)
break
if
hasattr
(
trainer
,
"finalize"
):
trainer
.
finalize
(
machine
,
data
)
...
...
@@ -109,7 +110,7 @@ def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True, rng=N
# V Subspace
logger
.
info
(
"V subspace estimation..."
)
for
i
in
range
(
max_iterations
):
logger
.
info
(
"Iteration = %d/%d"
,
i
,
max_iterations
)
logger
.
debug
(
"Iteration = %d/%d"
,
i
+
1
,
max_iterations
)
trainer
.
e_step_v
(
jfa_base
,
data
)
trainer
.
m_step_v
(
jfa_base
,
data
)
trainer
.
finalize_v
(
jfa_base
,
data
)
...
...
@@ -117,7 +118,7 @@ def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True, rng=N
# U subspace
logger
.
info
(
"U subspace estimation..."
)
for
i
in
range
(
max_iterations
):
logger
.
info
(
"Iteration = %d/%d"
,
i
,
max_iterations
)
logger
.
debug
(
"Iteration = %d/%d"
,
i
+
1
,
max_iterations
)
trainer
.
e_step_u
(
jfa_base
,
data
)
trainer
.
m_step_u
(
jfa_base
,
data
)
trainer
.
finalize_u
(
jfa_base
,
data
)
...
...
@@ -125,7 +126,7 @@ def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True, rng=N
# D subspace
logger
.
info
(
"D subspace estimation..."
)
for
i
in
range
(
max_iterations
):
logger
.
info
(
"Iteration = %d/%d"
,
i
,
max_iterations
)
logger
.
debug
(
"Iteration = %d/%d"
,
i
+
1
,
max_iterations
)
trainer
.
e_step_d
(
jfa_base
,
data
)
trainer
.
m_step_d
(
jfa_base
,
data
)
trainer
.
finalize_d
(
jfa_base
,
data
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment