From af6fd843a1a4afff3ea1ec29b3b4a972a03c411b Mon Sep 17 00:00:00 2001
From: Elie Khoury <elie.khoury@idiap.ch>
Date: Wed, 6 May 2015 09:29:17 +0200
Subject: [PATCH] added logging info for the EM training

---
 bob/learn/em/test/test_em.py |  3 +++
 bob/learn/em/train.py        | 13 ++++++++++---
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/bob/learn/em/test/test_em.py b/bob/learn/em/test/test_em.py
index 0dc6281..97f1a2b 100644
--- a/bob/learn/em/test/test_em.py
+++ b/bob/learn/em/test/test_em.py
@@ -18,6 +18,9 @@ from bob.learn.em import KMeansMachine, GMMMachine, KMeansTrainer, \
 
 import bob.learn.em
 
+import bob.core
+bob.core.log.setup("bob.learn.em")
+
 #, MAP_GMMTrainer
 
 def loadGMM():
diff --git a/bob/learn/em/train.py b/bob/learn/em/train.py
index ee2ecd4..6bb18e3 100644
--- a/bob/learn/em/train.py
+++ b/bob/learn/em/train.py
@@ -6,6 +6,8 @@
 # Copyright (C) 2011-2015 Idiap Research Institute, Martigny, Switzerland
 import numpy
 import bob.learn.em
+import logging
+logger = logging.getLogger('bob.learn.em')
 
 def train(trainer, machine, data, max_iterations = 50, convergence_threshold=None, initialize=True, rng=None):
   """
@@ -42,16 +44,21 @@ def train(trainer, machine, data, max_iterations = 50, convergence_threshold=Non
     average_output          = trainer.compute_likelihood(machine)
 
   for i in range(max_iterations):
+    
     average_output_previous = average_output
     trainer.m_step(machine, data)
     trainer.e_step(machine, data)
 
     if convergence_threshold!=None and hasattr(trainer,"compute_likelihood"):
       average_output = trainer.compute_likelihood(machine)
-
+    else:
+      logger.info("Iteration = %d ", i)
     #Terminates if converged (and likelihood computation is set)
-    if convergence_threshold!=None and abs((average_output_previous - average_output)/average_output_previous) <= convergence_threshold:
-      break
+    if convergence_threshold!=None: 
+      convergence_value = abs((average_output_previous - average_output)/average_output_previous)
+      logger.info("Iteration = %d \t convergence value = %f ", i, convergence_value)
+      if convergence_value <= convergence_threshold:
+        break
 
   if hasattr(trainer,"finalize"):
     trainer.finalize(machine, data)
-- 
GitLab