Commit 814ab98d authored by Manuel Günther

Merge branch 'master' of https://github.com/bioidiap/bob.learn.em

parents 28aba742 e98f7784
@@ -18,6 +18,9 @@ from bob.learn.em import KMeansMachine, GMMMachine, KMeansTrainer, \
 import bob.learn.em
+import bob.core
+bob.core.log.setup("bob.learn.em")
 #, MAP_GMMTrainer
 def loadGMM():
...
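The bob.core.log.setup("bob.learn.em") call added here registers bob's logging handlers on the package logger, so the iteration messages introduced in the training helpers below become visible. As a rough illustration (the handler and level choices are assumptions for this sketch, not part of the commit), the same logger can also be configured with the standard logging module alone:

    import logging

    # train() and train_jfa() below log on the 'bob.learn.em' logger;
    # attach a console handler and lower the level to see their messages.
    logger = logging.getLogger("bob.learn.em")
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)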
@@ -6,8 +6,11 @@
 # Copyright (C) 2011-2015 Idiap Research Institute, Martigny, Switzerland
 import numpy
 import bob.learn.em
+import logging
+logger = logging.getLogger('bob.learn.em')
 def train(trainer, machine, data, max_iterations = 50, convergence_threshold=None, initialize=True, rng=None):
   """
   Trains a machine given a trainer and the proper data
@@ -38,21 +41,24 @@ def train(trainer, machine, data, max_iterations = 50, convergence_threshold=Non
   average_output = 0
   average_output_previous = 0
-  if convergence_threshold!=None and hasattr(trainer,"compute_likelihood"):
+  if hasattr(trainer,"compute_likelihood"):
     average_output = trainer.compute_likelihood(machine)
   for i in range(max_iterations):
+    logger.info("Iteration = %d/%d", i, max_iterations)
     average_output_previous = average_output
     trainer.m_step(machine, data)
     trainer.e_step(machine, data)
-    if convergence_threshold!=None and hasattr(trainer,"compute_likelihood"):
+    if hasattr(trainer,"compute_likelihood"):
       average_output = trainer.compute_likelihood(machine)
+      logger.info("log likelihood = %f", average_output)
+      convergence_value = abs((average_output_previous - average_output)/average_output_previous)
+      logger.info("convergence value = %f", convergence_value)
-    #Terminates if converged (and likelihood computation is set)
-    if convergence_threshold!=None and abs((average_output_previous - average_output)/average_output_previous) <= convergence_threshold:
-      break
+      #Terminates if converged (and likelihood computation is set)
+      if convergence_threshold!=None and convergence_value <= convergence_threshold:
+        break
   if hasattr(trainer,"finalize"):
     trainer.finalize(machine, data)
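The reworked loop now computes the likelihood whenever the trainer provides compute_likelihood, logs it, and derives the convergence value as the relative change abs((previous - current)/previous); convergence_threshold only decides whether to stop early. A minimal usage sketch with the k-means classes imported above (the data, machine size and threshold are made-up illustration values, and it assumes this train() helper is in scope):

    import numpy
    import bob.learn.em

    # Toy 2-D samples around two centres (illustration only).
    data = numpy.array([[0.0, 0.0], [0.1, 0.2], [5.0, 5.0], [5.2, 4.9]])

    machine = bob.learn.em.KMeansMachine(2, 2)  # 2 means, 2-dimensional input
    trainer = bob.learn.em.KMeansTrainer()

    # Stops after 50 iterations or once the relative change in the
    # trainer's likelihood/distance measure falls below the threshold.
    train(trainer, machine, data, max_iterations=50, convergence_threshold=1e-5)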
@@ -83,19 +89,25 @@ def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True, rng=N
     trainer.initialize(jfa_base, data)
   #V Subspace
+  logger.info("V subspace estimation...")
   for i in range(max_iterations):
+    logger.info("Iteration = %d/%d", i, max_iterations)
     trainer.e_step_v(jfa_base, data)
     trainer.m_step_v(jfa_base, data)
   trainer.finalize_v(jfa_base, data)
   #U subspace
+  logger.info("U subspace estimation...")
   for i in range(max_iterations):
+    logger.info("Iteration = %d/%d", i, max_iterations)
     trainer.e_step_u(jfa_base, data)
     trainer.m_step_u(jfa_base, data)
   trainer.finalize_u(jfa_base, data)
-  # d subspace
+  # D subspace
+  logger.info("D subspace estimation...")
   for i in range(max_iterations):
+    logger.info("Iteration = %d/%d", i, max_iterations)
     trainer.e_step_d(jfa_base, data)
     trainer.m_step_d(jfa_base, data)
   trainer.finalize_d(jfa_base, data)
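train_jfa() now announces which subspace (V, U or D) is being estimated and the iteration index for each of the three loops. A hedged usage sketch follows; the JFABase rank values, the ubm and training_stats placeholders, and the exact constructor signatures are assumptions for illustration rather than something fixed by this commit:

    import bob.learn.em

    # ubm: a previously trained GMMMachine used as universal background model.
    # training_stats: a list (one entry per identity) of lists of GMMStats.
    jfa_base = bob.learn.em.JFABase(ubm, 2, 2)  # placeholder ranks rU = 2, rV = 2
    trainer = bob.learn.em.JFATrainer()

    # Runs 10 V-, 10 U- and 10 D-subspace iterations, logging progress as above.
    train_jfa(trainer, jfa_base, training_stats, max_iterations=10)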