diff --git a/bob/learn/em/train.py b/bob/learn/em/train.py index 6bb18e3451d82a2aee524ed7c5267ab6326e0d9e..e5c2f074c6a8054780dfc0e11858fc16ea004c8a 100644 --- a/bob/learn/em/train.py +++ b/bob/learn/em/train.py @@ -10,6 +10,7 @@ import logging logger = logging.getLogger('bob.learn.em') def train(trainer, machine, data, max_iterations = 50, convergence_threshold=None, initialize=True, rng=None): + """ Trains a machine given a trainer and the proper data @@ -40,26 +41,24 @@ def train(trainer, machine, data, max_iterations = 50, convergence_threshold=Non average_output = 0 average_output_previous = 0 - if convergence_threshold!=None and hasattr(trainer,"compute_likelihood"): + if hasattr(trainer,"compute_likelihood"): average_output = trainer.compute_likelihood(machine) for i in range(max_iterations): - + logger.info("Iteration = %d/%d", i, max_iterations) average_output_previous = average_output trainer.m_step(machine, data) trainer.e_step(machine, data) - - if convergence_threshold!=None and hasattr(trainer,"compute_likelihood"): + + if hasattr(trainer,"compute_likelihood"): average_output = trainer.compute_likelihood(machine) - else: - logger.info("Iteration = %d ", i) - #Terminates if converged (and likelihood computation is set) - if convergence_threshold!=None: + logger.info("log likelihood = %f", average_output) convergence_value = abs((average_output_previous - average_output)/average_output_previous) - logger.info("Iteration = %d \t convergence value = %f ", i, convergence_value) - if convergence_value <= convergence_threshold: + logger.info("convergence value = %f",convergence_value) + + #Terminates if converged (and likelihood computation is set) + if convergence_threshold!=None and convergence_value <= convergence_threshold: break - if hasattr(trainer,"finalize"): trainer.finalize(machine, data) @@ -90,19 +89,25 @@ def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True, rng=N trainer.initialize(jfa_base, data) #V Subspace + logger.info("V subspace estimation...") for i in range(max_iterations): + logger.info("Iteration = %d/%d", i, max_iterations) trainer.e_step_v(jfa_base, data) trainer.m_step_v(jfa_base, data) trainer.finalize_v(jfa_base, data) #U subspace + logger.info("U subspace estimation...") for i in range(max_iterations): + logger.info("Iteration = %d/%d", i, max_iterations) trainer.e_step_u(jfa_base, data) trainer.m_step_u(jfa_base, data) trainer.finalize_u(jfa_base, data) - # d subspace + # D subspace + logger.info("D subspace estimation...") for i in range(max_iterations): + logger.info("Iteration = %d/%d", i, max_iterations) trainer.e_step_d(jfa_base, data) trainer.m_step_d(jfa_base, data) trainer.finalize_d(jfa_base, data)