Skip to content
Snippets Groups Projects
Commit b73834a6, authored by Amir MOHAMMADI
Browse files

Automatically create threadpool and allow custom trainers

parent 7d56b707
No related branches found
No related tags found
1 merge request: !16 "Trial to implement EM with multiprocessing"
Pipeline #29959 failed
......@@ -7,11 +7,12 @@
import numpy
from ._library import *
import logging
from multiprocessing.pool import ThreadPool
logger = logging.getLogger(__name__)
def _set_average(trainer, trainers, machine, data):
def _set_average(trainer, trainers, machine, data, trainer_type):
"""_set_average(trainer, data) -> None
This function computes the average of the given data and sets it to the given machine.
......@@ -33,7 +34,7 @@ def _set_average(trainer, trainers, machine, data):
Usually this list is generated by parallelizing the e-step of the ``trainer``.
"""
if isinstance(trainer, KMeansTrainer):
if trainer_type == "KMeansTrainer":
# K-Means statistics
trainer.reset_accumulators(machine)
for t in trainers:
......@@ -41,10 +42,9 @@ def _set_average(trainer, trainers, machine, data):
trainer.first_order_statistics = trainer.first_order_statistics + t.first_order_statistics
trainer.average_min_distance = trainer.average_min_distance + t.average_min_distance
#trainer.average_min_distance /= sum(d.shape[0] for d in data)
trainer.average_min_distance /= data.shape[0]
elif isinstance(trainer, (ML_GMMTrainer, MAP_GMMTrainer)):
elif trainer_type in ("ML_GMMTrainer", "MAP_GMMTrainer"):
# GMM statistics
trainer.gmm_statistics = trainers[0].gmm_statistics
for t in trainers[1:]:
......@@ -62,7 +62,7 @@ def _parallel_e_step(args):
trainer.e_step(machine, data)
def train(trainer, machine, data, max_iterations=50, convergence_threshold=None, initialize=True, rng=None, check_inputs=True, pool=None):
def train(trainer, machine, data, max_iterations=50, convergence_threshold=None, initialize=True, rng=None, check_inputs=True, pool=None, trainer_type=None):
"""
Trains a machine given a trainer and the proper data
......@@ -84,8 +84,15 @@ def train(trainer, machine, data, max_iterations=50, convergence_threshold=None,
The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop
check_inputs:
Shallow checks in the inputs. Check for inf and NaN
pool : :py:class:`multiprocessing.Pool` or ``None``
If given, the provided process pool will be used to parallelize the M-step of the EM algorithm
pool : ``int`` or :py:class:`multiprocessing.Pool` or ``None``
If given, the provided process pool will be used to parallelize the M-step of the
EM algorithm. You should provide a ThreadPool not a multi process Pool. If pool is
an integer, it will be used to create a ThreadPool with that many processes.
trainer_type : ``str`` or ``None``
This is used for the parallel e_step method to see how several processes' data can
be merged into one trainer before the m_step. By default
``trainer.__class__.__name__`` is used. This is useful if you have custom trainers
and want to use this function.
"""
if check_inputs and isinstance(data, numpy.ndarray):
sum_data = numpy.sum(data)
......@@ -96,6 +103,12 @@ def train(trainer, machine, data, max_iterations=50, convergence_threshold=None,
if numpy.isnan(sum_data):
raise ValueError("Please, check your inputs; numpy.nan detected in `data` ")
if isinstance(pool, int):
pool = ThreadPool(pool)
if trainer_type is None:
trainer_type = trainer.__class__.__name__
def _e_step(trainer, machine, data):
# performs the e-step, possibly in parallel
......@@ -127,7 +140,7 @@ def train(trainer, machine, data, max_iterations=50, convergence_threshold=None,
# call the parallel processes
pool.map(_parallel_e_step, zip(trainers, machines, split_data))
# update the trainer with the data of the other trainers
_set_average(trainer, trainers, machine, data)
_set_average(trainer, trainers, machine, data, trainer_type)
# Initialization
if initialize:
......@@ -192,7 +205,7 @@ def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True, rng=N
else:
trainer.initialize(jfa_base, data)
#V Subspace
# V Subspace
logger.info("V subspace estimation...")
for i in range(max_iterations):
logger.info("Iteration = %d/%d", i, max_iterations)
......@@ -200,7 +213,7 @@ def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True, rng=N
trainer.m_step_v(jfa_base, data)
trainer.finalize_v(jfa_base, data)
#U subspace
# U subspace
logger.info("U subspace estimation...")
for i in range(max_iterations):
logger.info("Iteration = %d/%d", i, max_iterations)
......@@ -215,4 +228,3 @@ def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True, rng=N
trainer.e_step_d(jfa_base, data)
trainer.m_step_d(jfa_base, data)
trainer.finalize_d(jfa_base, data)
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment.