Skip to content
Snippets Groups Projects
Commit 814ab98d authored by Manuel Günther's avatar Manuel Günther
Browse files

Merge branch 'master' of https://github.com/bioidiap/bob.learn.em

parents 28aba742 e98f7784
No related branches found
No related tags found
No related merge requests found
......@@ -18,6 +18,9 @@ from bob.learn.em import KMeansMachine, GMMMachine, KMeansTrainer, \
import bob.learn.em
import bob.core
bob.core.log.setup("bob.learn.em")
#, MAP_GMMTrainer
def loadGMM():
......
......@@ -6,8 +6,11 @@
# Copyright (C) 2011-2015 Idiap Research Institute, Martigny, Switzerland
import numpy
import bob.learn.em
import logging
logger = logging.getLogger('bob.learn.em')
def train(trainer, machine, data, max_iterations = 50, convergence_threshold=None, initialize=True, rng=None):
"""
Trains a machine given a trainer and the proper data
......@@ -38,21 +41,24 @@ def train(trainer, machine, data, max_iterations = 50, convergence_threshold=Non
average_output = 0
average_output_previous = 0
if convergence_threshold!=None and hasattr(trainer,"compute_likelihood"):
if hasattr(trainer,"compute_likelihood"):
average_output = trainer.compute_likelihood(machine)
for i in range(max_iterations):
logger.info("Iteration = %d/%d", i, max_iterations)
average_output_previous = average_output
trainer.m_step(machine, data)
trainer.e_step(machine, data)
if convergence_threshold!=None and hasattr(trainer,"compute_likelihood"):
if hasattr(trainer,"compute_likelihood"):
average_output = trainer.compute_likelihood(machine)
#Terminates if converged (and likelihood computation is set)
if convergence_threshold!=None and abs((average_output_previous - average_output)/average_output_previous) <= convergence_threshold:
break
logger.info("log likelihood = %f", average_output)
convergence_value = abs((average_output_previous - average_output)/average_output_previous)
logger.info("convergence value = %f",convergence_value)
#Terminates if converged (and likelihood computation is set)
if convergence_threshold!=None and convergence_value <= convergence_threshold:
break
if hasattr(trainer,"finalize"):
trainer.finalize(machine, data)
......@@ -83,19 +89,25 @@ def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True, rng=N
trainer.initialize(jfa_base, data)
#V Subspace
logger.info("V subspace estimation...")
for i in range(max_iterations):
logger.info("Iteration = %d/%d", i, max_iterations)
trainer.e_step_v(jfa_base, data)
trainer.m_step_v(jfa_base, data)
trainer.finalize_v(jfa_base, data)
#U subspace
logger.info("U subspace estimation...")
for i in range(max_iterations):
logger.info("Iteration = %d/%d", i, max_iterations)
trainer.e_step_u(jfa_base, data)
trainer.m_step_u(jfa_base, data)
trainer.finalize_u(jfa_base, data)
# d subspace
# D subspace
logger.info("D subspace estimation...")
for i in range(max_iterations):
logger.info("Iteration = %d/%d", i, max_iterations)
trainer.e_step_d(jfa_base, data)
trainer.m_step_d(jfa_base, data)
trainer.finalize_d(jfa_base, data)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment