Skip to content
Snippets Groups Projects
Commit 0abaac5f authored by Manuel Günther's avatar Manuel Günther
Browse files

Made EM and JFA training less verbose

parent 4e704a43
No related branches found
No related tags found
1 merge request!30Resolve "training is very verbose"
Pipeline #
......@@ -58,7 +58,7 @@ def train(trainer, machine, data, max_iterations=50, convergence_threshold=None,
average_output = trainer.compute_likelihood(machine)
for i in range(max_iterations):
logger.info("Iteration = %d/%d", i, max_iterations)
logger.debug("Iteration = %d/%d", i+1, max_iterations)
average_output_previous = average_output
trainer.m_step(machine, data)
trainer.e_step(machine, data)
......@@ -67,15 +67,16 @@ def train(trainer, machine, data, max_iterations=50, convergence_threshold=None,
average_output = trainer.compute_likelihood(machine)
if type(machine) is bob.learn.em.KMeansMachine:
logger.info("average euclidean distance = %f", average_output)
logger.debug("average euclidean distance = %f", average_output)
else:
logger.info("log likelihood = %f", average_output)
logger.debug("log likelihood = %f", average_output)
convergence_value = abs((average_output_previous - average_output) / average_output_previous)
logger.info("convergence value = %f", convergence_value)
logger.debug("convergence value = %f", convergence_value)
# Terminates if converged (and likelihood computation is set)
if convergence_threshold != None and convergence_value <= convergence_threshold:
logger.info("EM training converged after %d iterations with convergence value %f", convergence_value)
break
if hasattr(trainer, "finalize"):
trainer.finalize(machine, data)
......@@ -109,7 +110,7 @@ def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True, rng=N
# V Subspace
logger.info("V subspace estimation...")
for i in range(max_iterations):
logger.info("Iteration = %d/%d", i, max_iterations)
logger.debug("Iteration = %d/%d", i+1, max_iterations)
trainer.e_step_v(jfa_base, data)
trainer.m_step_v(jfa_base, data)
trainer.finalize_v(jfa_base, data)
......@@ -117,7 +118,7 @@ def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True, rng=N
# U subspace
logger.info("U subspace estimation...")
for i in range(max_iterations):
logger.info("Iteration = %d/%d", i, max_iterations)
logger.debug("Iteration = %d/%d", i+1, max_iterations)
trainer.e_step_u(jfa_base, data)
trainer.m_step_u(jfa_base, data)
trainer.finalize_u(jfa_base, data)
......@@ -125,7 +126,7 @@ def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True, rng=N
# D subspace
logger.info("D subspace estimation...")
for i in range(max_iterations):
logger.info("Iteration = %d/%d", i, max_iterations)
logger.debug("Iteration = %d/%d", i+1, max_iterations)
trainer.e_step_d(jfa_base, data)
trainer.m_step_d(jfa_base, data)
trainer.finalize_d(jfa_base, data)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment