Skip to content
Snippets Groups Projects
Commit 2c98e036 authored by Manuel Günther's avatar Manuel Günther
Browse files

Fixes small issues in parallel UBM script

parent b49cea41
No related branches found
No related tags found
No related merge requests found
...@@ -36,7 +36,7 @@ def kmeans_initialize(algorithm, extractor, limit_data = None, force = False): ...@@ -36,7 +36,7 @@ def kmeans_initialize(algorithm, extractor, limit_data = None, force = False):
def kmeans_estep(algorithm, extractor, iteration, indices, force=False): def kmeans_estep(algorithm, extractor, iteration, indices, force=False):
"""Performs a single E-step of the K-Means algorithm (parallel)""" """Performs a single E-step of the K-Means algorithm (parallel)"""
if indices[0] > indices[1]: if indices[0] >= indices[1]:
return return
fs = FileSelector.instance() fs = FileSelector.instance()
...@@ -142,7 +142,8 @@ def kmeans_mstep(algorithm, iteration, number_of_parallel_jobs, force=False, cle ...@@ -142,7 +142,8 @@ def kmeans_mstep(algorithm, iteration, number_of_parallel_jobs, force=False, cle
bob.io.base.create_directories_safe(os.path.dirname(new_machine_file)) bob.io.base.create_directories_safe(os.path.dirname(new_machine_file))
kmeans_machine.save(bob.io.base.HDF5File(new_machine_file, 'w')) kmeans_machine.save(bob.io.base.HDF5File(new_machine_file, 'w'))
# copy the k_means file in any case # copy the k_means file, when last iteration
# TODO: implement other stopping criteria
if iteration == algorithm.kmeans_training_iterations-1: if iteration == algorithm.kmeans_training_iterations-1:
shutil.copy(new_machine_file, fs.kmeans_file) shutil.copy(new_machine_file, fs.kmeans_file)
logger.info("UBM training: Wrote new KMeans machine '%s'", fs.kmeans_file) logger.info("UBM training: Wrote new KMeans machine '%s'", fs.kmeans_file)
...@@ -191,7 +192,7 @@ def gmm_initialize(algorithm, extractor, limit_data = None, force = False): ...@@ -191,7 +192,7 @@ def gmm_initialize(algorithm, extractor, limit_data = None, force = False):
def gmm_estep(algorithm, extractor, iteration, indices, force=False): def gmm_estep(algorithm, extractor, iteration, indices, force=False):
"""Performs a single E-step of the GMM training (parallel).""" """Performs a single E-step of the GMM training (parallel)."""
if indices[0] > indices[1]: if indices[0] >= indices[1]:
return return
fs = FileSelector.instance() fs = FileSelector.instance()
...@@ -199,7 +200,7 @@ def gmm_estep(algorithm, extractor, iteration, indices, force=False): ...@@ -199,7 +200,7 @@ def gmm_estep(algorithm, extractor, iteration, indices, force=False):
new_machine_file = fs.gmm_intermediate_file(iteration + 1) new_machine_file = fs.gmm_intermediate_file(iteration + 1)
if utils.check_file(stats_file, force, 1000) or utils.check_file(new_machine_file, force, 1000): if utils.check_file(stats_file, force, 1000) or utils.check_file(new_machine_file, force, 1000):
loggerinfo("UBM training: Skipping GMM E-Step since the file '%s' or '%s' already exists", stats_file, new_machine_file) logger.info("UBM training: Skipping GMM E-Step since the file '%s' or '%s' already exists", stats_file, new_machine_file)
else: else:
training_list = fs.training_list('extracted', 'train_projector') training_list = fs.training_list('extracted', 'train_projector')
last_machine_file = fs.gmm_intermediate_file(iteration) last_machine_file = fs.gmm_intermediate_file(iteration)
...@@ -268,6 +269,8 @@ def gmm_mstep(algorithm, iteration, number_of_parallel_jobs, force=False, clean= ...@@ -268,6 +269,8 @@ def gmm_mstep(algorithm, iteration, number_of_parallel_jobs, force=False, clean=
bob.io.base.create_directories_safe(os.path.dirname(new_machine_file)) bob.io.base.create_directories_safe(os.path.dirname(new_machine_file))
gmm_machine.save(bob.io.base.HDF5File(new_machine_file, 'w')) gmm_machine.save(bob.io.base.HDF5File(new_machine_file, 'w'))
# Write the final UBM file after the last iteration
# TODO: implement other stopping criteria
if iteration == algorithm.gmm_training_iterations-1: if iteration == algorithm.gmm_training_iterations-1:
shutil.copy(new_machine_file, fs.ubm_file) shutil.copy(new_machine_file, fs.ubm_file)
logger.info("UBM training: Wrote new UBM '%s'", fs.ubm_file) logger.info("UBM training: Wrote new UBM '%s'", fs.ubm_file)
......
...@@ -33,10 +33,10 @@ ...@@ -33,10 +33,10 @@
# allows you to test your package with new python dependencies w/o requiring # allows you to test your package with new python dependencies w/o requiring
# administrative interventions. # administrative interventions.
from setuptools import setup, find_packages, dist from setuptools import setup, dist
dist.Distribution(dict(setup_requires=['bob.extension'])) dist.Distribution(dict(setup_requires=['bob.extension']))
from bob.extension.utils import load_requirements from bob.extension.utils import load_requirements, find_packages
install_requires = load_requirements() install_requires = load_requirements()
# The only thing we do in this file is to call the setup() function with all # The only thing we do in this file is to call the setup() function with all
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment