Commit af6fd843 authored by Elie KHOURY's avatar Elie KHOURY
Browse files

added logging info for the EM training

parent 2782bcab
......@@ -18,6 +18,9 @@ from bob.learn.em import KMeansMachine, GMMMachine, KMeansTrainer, \
import bob.learn.em
import bob.core
bob.core.log.setup("bob.learn.em")
#, MAP_GMMTrainer
def loadGMM():
......
......@@ -6,6 +6,8 @@
# Copyright (C) 2011-2015 Idiap Research Institute, Martigny, Switzerland
import numpy
import bob.learn.em
import logging
logger = logging.getLogger('bob.learn.em')
def train(trainer, machine, data, max_iterations = 50, convergence_threshold=None, initialize=True, rng=None):
"""
......@@ -42,16 +44,21 @@ def train(trainer, machine, data, max_iterations = 50, convergence_threshold=Non
average_output = trainer.compute_likelihood(machine)
for i in range(max_iterations):
average_output_previous = average_output
trainer.m_step(machine, data)
trainer.e_step(machine, data)
if convergence_threshold!=None and hasattr(trainer,"compute_likelihood"):
average_output = trainer.compute_likelihood(machine)
else:
logger.info("Iteration = %d ", i)
#Terminates if converged (and likelihood computation is set)
if convergence_threshold!=None and abs((average_output_previous - average_output)/average_output_previous) <= convergence_threshold:
break
if convergence_threshold!=None:
convergence_value = abs((average_output_previous - average_output)/average_output_previous)
logger.info("Iteration = %d \t convergence value = %f ", i, convergence_value)
if convergence_value <= convergence_threshold:
break
if hasattr(trainer,"finalize"):
trainer.finalize(machine, data)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment