Skip to content
Snippets Groups Projects
Commit af6fd843 authored by Elie KHOURY's avatar Elie KHOURY
Browse files

added logging info for the EM training

parent 2782bcab
No related branches found
No related tags found
No related merge requests found
...@@ -18,6 +18,9 @@ from bob.learn.em import KMeansMachine, GMMMachine, KMeansTrainer, \ ...@@ -18,6 +18,9 @@ from bob.learn.em import KMeansMachine, GMMMachine, KMeansTrainer, \
import bob.learn.em import bob.learn.em
import bob.core
bob.core.log.setup("bob.learn.em")
#, MAP_GMMTrainer #, MAP_GMMTrainer
def loadGMM(): def loadGMM():
......
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
# Copyright (C) 2011-2015 Idiap Research Institute, Martigny, Switzerland # Copyright (C) 2011-2015 Idiap Research Institute, Martigny, Switzerland
import numpy import numpy
import bob.learn.em import bob.learn.em
import logging
logger = logging.getLogger('bob.learn.em')
def train(trainer, machine, data, max_iterations = 50, convergence_threshold=None, initialize=True, rng=None): def train(trainer, machine, data, max_iterations = 50, convergence_threshold=None, initialize=True, rng=None):
""" """
...@@ -42,16 +44,21 @@ def train(trainer, machine, data, max_iterations = 50, convergence_threshold=Non ...@@ -42,16 +44,21 @@ def train(trainer, machine, data, max_iterations = 50, convergence_threshold=Non
average_output = trainer.compute_likelihood(machine) average_output = trainer.compute_likelihood(machine)
for i in range(max_iterations): for i in range(max_iterations):
average_output_previous = average_output average_output_previous = average_output
trainer.m_step(machine, data) trainer.m_step(machine, data)
trainer.e_step(machine, data) trainer.e_step(machine, data)
if convergence_threshold!=None and hasattr(trainer,"compute_likelihood"): if convergence_threshold!=None and hasattr(trainer,"compute_likelihood"):
average_output = trainer.compute_likelihood(machine) average_output = trainer.compute_likelihood(machine)
else:
logger.info("Iteration = %d ", i)
#Terminates if converged (and likelihood computation is set) #Terminates if converged (and likelihood computation is set)
if convergence_threshold!=None and abs((average_output_previous - average_output)/average_output_previous) <= convergence_threshold: if convergence_threshold!=None:
break convergence_value = abs((average_output_previous - average_output)/average_output_previous)
logger.info("Iteration = %d \t convergence value = %f ", i, convergence_value)
if convergence_value <= convergence_threshold:
break
if hasattr(trainer,"finalize"): if hasattr(trainer,"finalize"):
trainer.finalize(machine, data) trainer.finalize(machine, data)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment