Commit 2ea663bf authored by Amir MOHAMMADI

Add input checks back and enable tests

parent 607e6a23
Pipeline #29920 failed with stage in 11 minutes and 17 seconds
@@ -184,11 +184,10 @@ def test_trainer_execption():
   machine = KMeansMachine(2, 2)
   data = numpy.array([[1.0, 2.0], [2, 3.], [1, 1.], [2, 5.], [numpy.inf, 1.0]])
   trainer = KMeansTrainer()
-  bob.learn.em.train(trainer, machine, data, 10)
-  #assert_raises(ValueError, bob.learn.em.train, trainer, machine, data, 10)
+  assert_raises(ValueError, bob.learn.em.train, trainer, machine, data, 10)
 
   # Testing Nan
   machine = KMeansMachine(2, 2)
   data = numpy.array([[1.0, 2.0], [2, 3.], [1, numpy.nan], [2, 5.], [2.0, 1.0]])
   trainer = KMeansTrainer()
-  #assert_raises(ValueError, bob.learn.em.train, trainer, machine, data, 10)
+  assert_raises(ValueError, bob.learn.em.train, trainer, machine, data, 10)
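For reference, a minimal sketch of what the re-enabled assertions exercise, assuming `assert_raises` comes from `nose.tools` as is typical in bob test modules: training on data containing `numpy.inf` or `numpy.nan` must now raise a `ValueError` instead of running to completion.

```python
import numpy
from nose.tools import assert_raises  # assumed import, as in typical bob test modules

import bob.learn.em
from bob.learn.em import KMeansMachine, KMeansTrainer

# data containing an infinite value must be rejected by the new input checks
machine = KMeansMachine(2, 2)
data = numpy.array([[1.0, 2.0], [2.0, 3.0], [numpy.inf, 1.0]])
trainer = KMeansTrainer()
assert_raises(ValueError, bob.learn.em.train, trainer, machine, data, 10)
```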
@@ -7,10 +7,9 @@
 import numpy
 from ._library import *
 import logging
-import bob.learn.em
-logger = logging.getLogger(__name__)
+logger = logging.getLogger('bob.learn.em')
 
 def _set_average(trainer, trainers, data):
   """_set_average(trainer, trainers, data) -> None
@@ -46,7 +45,7 @@ def _set_average(trainer, trainers, data):
       trainer.average_min_distance += t.average_min_distance
     trainer.average_min_distance /= sum(d.shape[0] for d in data)
-  elif isinstance(trainer, ML_GMMTrainer):
+  elif isinstance(trainer, (ML_GMMTrainer, MAP_GMMTrainer)):
     # GMM statistics
     trainer.gmm_statistics = trainers[0].gmm_statistics
     for t in trainers[1:]:
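The hunk above widens the reduction step to cover MAP training as well. As a simplified illustration of what `_set_average` does conceptually (all names below are hypothetical stand-ins, not the library's classes): statistics accumulated independently by each per-process trainer are summed into the main trainer and normalized by the total sample count.

```python
class FakeTrainer:
    """Hypothetical stand-in for a trainer that accumulates a running statistic."""
    def __init__(self, average_min_distance=0.0):
        self.average_min_distance = average_min_distance

def set_average(main, workers, data_chunks):
    # sum the per-worker statistics, then normalize by the total number of samples
    main.average_min_distance = sum(w.average_min_distance for w in workers)
    main.average_min_distance /= sum(len(d) for d in data_chunks)

workers = [FakeTrainer(4.0), FakeTrainer(6.0)]
main = FakeTrainer()
set_average(main, workers, [range(50), range(50)])
print(main.average_min_distance)  # (4.0 + 6.0) / 100 = 0.1
```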
@@ -56,7 +55,6 @@ def _set_average(trainer, trainers, data):
     raise NotImplementedError("Implement Me!")
 
-
 def _parallel_e_step(args):
   """This function applies the e_step of the given trainer (first argument) on the given data (second argument).
   It is called by each parallel process.
@@ -65,7 +63,7 @@ def _parallel_e_step(args):
   trainer.e_step(machine, data)
 
-def train(trainer, machine, data, max_iterations = 50, convergence_threshold=None, initialize=True, rng=None, pool=None):
+def train(trainer, machine, data, max_iterations=50, convergence_threshold=None, initialize=True, rng=None, check_inputs=True, pool=None):
   """
   Trains a machine given a trainer and the proper data
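Given the new signature, callers can presumably opt out of the scan when the data is known to be clean, since `check_inputs` defaults to `True`. A hedged usage sketch:

```python
import numpy
import bob.learn.em
from bob.learn.em import KMeansMachine, KMeansTrainer

machine = KMeansMachine(2, 2)
trainer = KMeansTrainer()
data = numpy.random.normal(size=(1000, 2))  # large array already known to be finite

# skip the extra inf/NaN scan over the data
bob.learn.em.train(trainer, machine, data, max_iterations=10, check_inputs=False)
```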
@@ -85,9 +83,20 @@ def train(trainer, machine, data, max_iterations = 50, convergence_threshold=Non
     If True, runs the initialization procedure
   rng : :py:class:`bob.core.random.mt19937`
     The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loop
+  check_inputs:
+    Shallow checks on the inputs; looks for inf and NaN values
   pool : :py:class:`multiprocessing.Pool` or ``None``
     If given, the provided process pool will be used to parallelize the e-step of the EM algorithm
   """
+  data = numpy.asarray(data)
+
+  if check_inputs:
+    sum_data = numpy.sum(data)
+
+    if numpy.isinf(sum_data):
+      raise ValueError("Please check your inputs; numpy.inf detected in `data`")
+
+    if numpy.isnan(sum_data):
+      raise ValueError("Please check your inputs; numpy.nan detected in `data`")
 
   def _e_step(trainer, machine, data):
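The check reduces the whole array to one scalar with `numpy.sum` and then tests that scalar, so a single pass over the data suffices: any `inf` in the input propagates into the sum as `inf`, any `NaN` as `NaN`. A standalone sketch of the same idea (the function name is mine, not the library's):

```python
import numpy

def check_finite(data):
    """Single-pass sanity check: raise ValueError on inf or NaN in `data`."""
    total = numpy.sum(data)  # inf or NaN anywhere in the input propagates into the sum
    if numpy.isinf(total):
        raise ValueError("numpy.inf detected in `data`")
    if numpy.isnan(total):
        raise ValueError("numpy.nan detected in `data`")

check_finite(numpy.array([[1.0, 2.0], [3.0, 4.0]]))  # passes silently
# check_finite(numpy.array([[1.0, numpy.nan]]))      # would raise ValueError
```

One caveat of the sum trick: `+inf` and `-inf` together sum to `NaN`, so the error message can name NaN for data that actually holds infinities; `numpy.isfinite(data).all()` avoids that ambiguity at similar cost.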
@@ -97,38 +106,39 @@ def train(trainer, machine, data, max_iterations = 50, convergence_threshold=Non
       trainer.e_step(machine, data)
     else:
-      # use the given process pool
-      processes = pool._processes
+      # use the given process pool
+      n_processes = pool._processes
 
       # Mapping references of the data
       split_data = []
       offset = 0
-      step = data.shape[0]//processes
-      for p in range(processes):
-        if p == processes-1:
+      step = int(data.shape[0] // n_processes)
+      for p in range(n_processes):
+        if p == n_processes - 1:
           # take all the data in the last chunk
           split_data.append(data[offset:])
         else:
-          split_data.append(data[offset:offset+step])
+          split_data.append(data[offset: offset + step])
         offset += step
 
       # create trainers for each process
-      trainers = [trainer.__class__(trainer) for p in range(processes)]
-      machines = [machine.__class__(machine) for p in range(processes)]
+      trainers = [trainer.__class__(trainer) for p in range(n_processes)]
+      # no need to copy the machines
+      machines = [machine for p in range(n_processes)]
 
       # call the parallel processes
       pool.map(_parallel_e_step, zip(trainers, machines, split_data))
 
       # update the trainer with the data of the other trainers
       _set_average(trainer, trainers, data)
 
-  #Initialization
+  # Initialization
   if initialize:
     if rng is not None:
       trainer.initialize(machine, data, rng)
     else:
       trainer.initialize(machine, data)
 
-  _e_step(trainer, machine,data)
+  _e_step(trainer, machine, data)
 
   average_output = 0
   average_output_previous = 0
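The slicing loop above deals out `data` in contiguous chunks of `step` rows, with the last worker absorbing the remainder so no sample is dropped. A hypothetical standalone version of the same logic:

```python
import numpy

def split_for_workers(data, n_processes):
    """Split `data` row-wise into n_processes contiguous chunks; the last one takes the remainder."""
    step = int(data.shape[0] // n_processes)
    chunks, offset = [], 0
    for p in range(n_processes):
        if p == n_processes - 1:
            chunks.append(data[offset:])            # last chunk absorbs the remainder
        else:
            chunks.append(data[offset: offset + step])
        offset += step
    return chunks

parts = split_for_workers(numpy.arange(10).reshape(10, 1), 3)
print([p.shape[0] for p in parts])  # [3, 3, 4]
```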
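End to end, the parallel path would presumably be driven like this (a sketch using only names visible in this commit; the pool size and data are illustrative):

```python
import multiprocessing
import numpy

import bob.learn.em
from bob.learn.em import KMeansMachine, KMeansTrainer

data = numpy.random.normal(size=(200, 2))
machine = KMeansMachine(2, 2)
trainer = KMeansTrainer()

# each pool worker runs one e-step on its slice; _set_average merges the results
pool = multiprocessing.Pool(4)
bob.learn.em.train(trainer, machine, data, max_iterations=10, pool=pool)
pool.close()
pool.join()
```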