Commit e98f7784 authored by Elie KHOURY's avatar Elie KHOURY
Browse files

improving the logging info... and correcting division per zero error

parent af6fd843
......@@ -10,6 +10,7 @@ import logging
logger = logging.getLogger('bob.learn.em')
def train(trainer, machine, data, max_iterations = 50, convergence_threshold=None, initialize=True, rng=None):
"""
Trains a machine given a trainer and the proper data
......@@ -40,26 +41,24 @@ def train(trainer, machine, data, max_iterations = 50, convergence_threshold=Non
average_output = 0
average_output_previous = 0
if convergence_threshold!=None and hasattr(trainer,"compute_likelihood"):
if hasattr(trainer,"compute_likelihood"):
average_output = trainer.compute_likelihood(machine)
for i in range(max_iterations):
logger.info("Iteration = %d/%d", i, max_iterations)
average_output_previous = average_output
trainer.m_step(machine, data)
trainer.e_step(machine, data)
if convergence_threshold!=None and hasattr(trainer,"compute_likelihood"):
if hasattr(trainer,"compute_likelihood"):
average_output = trainer.compute_likelihood(machine)
else:
logger.info("Iteration = %d ", i)
#Terminates if converged (and likelihood computation is set)
if convergence_threshold!=None:
logger.info("log likelihood = %f", average_output)
convergence_value = abs((average_output_previous - average_output)/average_output_previous)
logger.info("Iteration = %d \t convergence value = %f ", i, convergence_value)
if convergence_value <= convergence_threshold:
logger.info("convergence value = %f",convergence_value)
#Terminates if converged (and likelihood computation is set)
if convergence_threshold!=None and convergence_value <= convergence_threshold:
break
if hasattr(trainer,"finalize"):
trainer.finalize(machine, data)
......@@ -90,19 +89,25 @@ def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True, rng=N
trainer.initialize(jfa_base, data)
#V Subspace
logger.info("V subspace estimation...")
for i in range(max_iterations):
logger.info("Iteration = %d/%d", i, max_iterations)
trainer.e_step_v(jfa_base, data)
trainer.m_step_v(jfa_base, data)
trainer.finalize_v(jfa_base, data)
#U subspace
logger.info("U subspace estimation...")
for i in range(max_iterations):
logger.info("Iteration = %d/%d", i, max_iterations)
trainer.e_step_u(jfa_base, data)
trainer.m_step_u(jfa_base, data)
trainer.finalize_u(jfa_base, data)
# d subspace
# D subspace
logger.info("D subspace estimation...")
for i in range(max_iterations):
logger.info("Iteration = %d/%d", i, max_iterations)
trainer.e_step_d(jfa_base, data)
trainer.m_step_d(jfa_base, data)
trainer.finalize_d(jfa_base, data)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment