Commit b73834a6 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Automatically create threadpool and allow custom trainers

parent 7d56b707
Pipeline #29959 failed with stage
in 13 minutes and 24 seconds
......@@ -7,11 +7,12 @@
import numpy
from ._library import *
import logging
from multiprocessing.pool import ThreadPool
logger = logging.getLogger(__name__)
def _set_average(trainer, trainers, machine, data):
def _set_average(trainer, trainers, machine, data, trainer_type):
"""_set_average(trainer, data) -> None
This function computes the average of the given data and sets it to the given machine.
......@@ -33,7 +34,7 @@ def _set_average(trainer, trainers, machine, data):
Usually this list is generated by parallelizing the e-step of the ``trainer``.
"""
if isinstance(trainer, KMeansTrainer):
if trainer_type == "KMeansTrainer":
# K-Means statistics
trainer.reset_accumulators(machine)
for t in trainers:
......@@ -41,10 +42,9 @@ def _set_average(trainer, trainers, machine, data):
trainer.first_order_statistics = trainer.first_order_statistics + t.first_order_statistics
trainer.average_min_distance = trainer.average_min_distance + t.average_min_distance
#trainer.average_min_distance /= sum(d.shape[0] for d in data)
trainer.average_min_distance /= data.shape[0]
elif isinstance(trainer, (ML_GMMTrainer, MAP_GMMTrainer)):
elif trainer_type in ("ML_GMMTrainer", "MAP_GMMTrainer"):
# GMM statistics
trainer.gmm_statistics = trainers[0].gmm_statistics
for t in trainers[1:]:
......@@ -62,7 +62,7 @@ def _parallel_e_step(args):
trainer.e_step(machine, data)
def train(trainer, machine, data, max_iterations=50, convergence_threshold=None, initialize=True, rng=None, check_inputs=True, pool=None):
def train(trainer, machine, data, max_iterations=50, convergence_threshold=None, initialize=True, rng=None, check_inputs=True, pool=None, trainer_type=None):
"""
Trains a machine given a trainer and the proper data
......@@ -84,8 +84,15 @@ def train(trainer, machine, data, max_iterations=50, convergence_threshold=None,
The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop
check_inputs:
Shallow checks in the inputs. Check for inf and NaN
pool : :py:class:`multiprocessing.Pool` or ``None``
If given, the provided process pool will be used to parallelize the M-step of the EM algorithm
pool : ``int`` or :py:class:`multiprocessing.pool.ThreadPool` or ``None``
If given, the provided pool will be used to parallelize the e-step of the
EM algorithm. You should provide a ThreadPool, not a multiprocessing Pool. If pool is
an integer, it will be used to create a ThreadPool with that many threads.
trainer_type : ``str`` or ``None``
This is used by the parallel e-step to decide how the statistics accumulated by
several workers should be merged into one trainer before the m-step. By default
``trainer.__class__.__name__`` is used. This is useful if you have custom trainers
and want to use this function.
"""
if check_inputs and isinstance(data, numpy.ndarray):
sum_data = numpy.sum(data)
......@@ -96,6 +103,12 @@ def train(trainer, machine, data, max_iterations=50, convergence_threshold=None,
if numpy.isnan(sum_data):
raise ValueError("Please, check your inputs; numpy.nan detected in `data` ")
if isinstance(pool, int):
pool = ThreadPool(pool)
if trainer_type is None:
trainer_type = trainer.__class__.__name__
def _e_step(trainer, machine, data):
# performs the e-step, possibly in parallel
......@@ -127,7 +140,7 @@ def train(trainer, machine, data, max_iterations=50, convergence_threshold=None,
# call the parallel processes
pool.map(_parallel_e_step, zip(trainers, machines, split_data))
# update the trainer with the data of the other trainers
_set_average(trainer, trainers, machine, data)
_set_average(trainer, trainers, machine, data, trainer_type)
# Initialization
if initialize:
......@@ -192,7 +205,7 @@ def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True, rng=N
else:
trainer.initialize(jfa_base, data)
#V Subspace
# V Subspace
logger.info("V subspace estimation...")
for i in range(max_iterations):
logger.info("Iteration = %d/%d", i, max_iterations)
......@@ -200,7 +213,7 @@ def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True, rng=N
trainer.m_step_v(jfa_base, data)
trainer.finalize_v(jfa_base, data)
#U subspace
# U subspace
logger.info("U subspace estimation...")
for i in range(max_iterations):
logger.info("Iteration = %d/%d", i, max_iterations)
......@@ -215,4 +228,3 @@ def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True, rng=N
trainer.e_step_d(jfa_base, data)
trainer.m_step_d(jfa_base, data)
trainer.finalize_d(jfa_base, data)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment