From af6fd843a1a4afff3ea1ec29b3b4a972a03c411b Mon Sep 17 00:00:00 2001 From: Elie Khoury <elie.khoury@idiap.ch> Date: Wed, 6 May 2015 09:29:17 +0200 Subject: [PATCH] added logging info for the EM training --- bob/learn/em/test/test_em.py | 3 +++ bob/learn/em/train.py | 13 ++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/bob/learn/em/test/test_em.py b/bob/learn/em/test/test_em.py index 0dc6281..97f1a2b 100644 --- a/bob/learn/em/test/test_em.py +++ b/bob/learn/em/test/test_em.py @@ -18,6 +18,9 @@ from bob.learn.em import KMeansMachine, GMMMachine, KMeansTrainer, \ import bob.learn.em +import bob.core +bob.core.log.setup("bob.learn.em") + #, MAP_GMMTrainer def loadGMM(): diff --git a/bob/learn/em/train.py b/bob/learn/em/train.py index ee2ecd4..6bb18e3 100644 --- a/bob/learn/em/train.py +++ b/bob/learn/em/train.py @@ -6,6 +6,8 @@ # Copyright (C) 2011-2015 Idiap Research Institute, Martigny, Switzerland import numpy import bob.learn.em +import logging +logger = logging.getLogger('bob.learn.em') def train(trainer, machine, data, max_iterations = 50, convergence_threshold=None, initialize=True, rng=None): """ @@ -42,16 +44,21 @@ def train(trainer, machine, data, max_iterations = 50, convergence_threshold=Non average_output = trainer.compute_likelihood(machine) for i in range(max_iterations): + average_output_previous = average_output trainer.m_step(machine, data) trainer.e_step(machine, data) if convergence_threshold!=None and hasattr(trainer,"compute_likelihood"): average_output = trainer.compute_likelihood(machine) - + else: + logger.info("Iteration = %d ", i) #Terminates if converged (and likelihood computation is set) - if convergence_threshold!=None and abs((average_output_previous - average_output)/average_output_previous) <= convergence_threshold: - break + if convergence_threshold!=None: + convergence_value = abs((average_output_previous - average_output)/average_output_previous) + logger.info("Iteration = %d \t convergence value = %f ", i, convergence_value) + if convergence_value <= convergence_threshold: + break if hasattr(trainer,"finalize"): trainer.finalize(machine, data) -- GitLab