Skip to content
Snippets Groups Projects
Commit 8121f54a authored by Yannick DAYER
Browse files

Fix IVector dask bags test

parent 67962115
No related branches found
No related tags found
1 merge request: !60 Port of I-Vector to Python
...@@ -183,18 +183,21 @@ def m_step( ...@@ -183,18 +183,21 @@ def m_step(
machine: "IVectorMachine", stats: List[IVectorStats] machine: "IVectorMachine", stats: List[IVectorStats]
) -> "IVectorMachine": ) -> "IVectorMachine":
"""Updates the Machine with the maximization step of the e-m algorithm.""" """Updates the Machine with the maximization step of the e-m algorithm."""
# Merge all the stats
stats = functools.reduce(operator.iadd, stats) stats = functools.reduce(operator.iadd, stats)
A = stats.nij_sigma_wij2.transpose((0, 2, 1)) A = stats.nij_sigma_wij2.transpose((0, 2, 1))
B = stats.fnorm_sigma_wij.transpose((0, 2, 1)) B = stats.fnorm_sigma_wij.transpose((0, 2, 1))
# Default value of X if any of A[c] is 0
X = np.zeros_like(B) X = np.zeros_like(B)
# Solve for all A != 0 # Solve for all A[c] != 0
if any(mask := A.any(axis=(-2, -1))): if any(mask := A.any(axis=(-2, -1))): # Prevents solving with 0 matrices
X[mask] = [ X[mask] = [
np.linalg.solve(A[c], B[c]) for c in range(len(mask)) if A[c].any() np.linalg.solve(A[c], B[c]) for c in range(len(mask)) if A[c].any()
] ]
# Update the machine
machine.T = X.transpose((0, 2, 1)) machine.T = X.transpose((0, 2, 1))
if machine.update_sigma: if machine.update_sigma:
...@@ -309,9 +312,6 @@ class IVectorMachine(BaseEstimator): ...@@ -309,9 +312,6 @@ class IVectorMachine(BaseEstimator):
new_machine = dask.compute(dask.delayed(m_step)(self, stats))[0] new_machine = dask.compute(dask.delayed(m_step)(self, stats))[0]
for attr in ["T", "sigma"]: for attr in ["T", "sigma"]:
setattr(self, attr, getattr(new_machine, attr)) setattr(self, attr, getattr(new_machine, attr))
self.T.persist()
self.sigma.persist()
else: else:
stats = [ stats = [
e_step( e_step(
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import contextlib import contextlib
import copy import copy
import dask.bag
import dask.distributed import dask.distributed
import numpy as np import numpy as np
...@@ -13,8 +14,7 @@ from pkg_resources import resource_filename ...@@ -13,8 +14,7 @@ from pkg_resources import resource_filename
from bob.learn.em import GMMMachine, GMMStats, IVectorMachine from bob.learn.em import GMMMachine, GMMStats, IVectorMachine
from bob.learn.em.ivector import e_step, m_step from bob.learn.em.ivector import e_step, m_step
from bob.learn.em.test.test_kmeans import to_numpy
from .test_kmeans import to_dask_array, to_numpy
@contextlib.contextmanager @contextlib.contextmanager
...@@ -27,6 +27,17 @@ def _dask_distributed_context(): ...@@ -27,6 +27,17 @@ def _dask_distributed_context():
client.close() client.close()
def to_dask_bag(*args):
    """Wrap each positional argument in a dask Bag.

    Every argument is coerced with ``np.asarray`` and handed to
    ``dask.bag.from_sequence``, requesting twice as many partitions as the
    array has entries along its first axis.

    Returns:
        The Bag itself when exactly one argument was given, otherwise a list
        holding one Bag per argument (an empty list when called without
        arguments).
    """
    # ``map`` is lazy, so each argument is converted right before its bag is
    # built, keeping the conversion/bag-creation order of a plain loop.
    bags = [
        dask.bag.from_sequence(arr, npartitions=arr.shape[0] * 2)
        for arr in map(np.asarray, args)
    ]
    # A lone input is unwrapped for the caller's convenience.
    return bags[0] if len(bags) == 1 else bags
def test_ivector_machine_base(): def test_ivector_machine_base():
# Create the UBM and set its values manually # Create the UBM and set its values manually
ubm = GMMMachine(n_gaussians=2) ubm = GMMMachine(n_gaussians=2)
...@@ -283,21 +294,22 @@ def test_ivector_fit(): ...@@ -283,21 +294,22 @@ def test_ivector_fit():
# Serial test # Serial test
np.random.seed(0) np.random.seed(0)
fit_data = to_numpy(fit_data) fit_data = to_numpy(fit_data)
projected_data = ubm.transform(d for d in fit_data) projected_data = ubm.transform(fit_data)
m = IVectorMachine(ubm=ubm, dim_t=2, max_iterations=2) m = IVectorMachine(ubm=ubm, dim_t=2, max_iterations=2)
m.fit(d for d in projected_data) m.fit(projected_data)
result = m.transform(ubm.transform(d for d in test_data)) result = m.transform(ubm.transform(test_data))
np.testing.assert_almost_equal(result, reference_result, decimal=5) np.testing.assert_almost_equal(result, reference_result, decimal=5)
# Parallel test # Parallel test
with _dask_distributed_context(): with _dask_distributed_context():
for transform in [to_numpy, to_dask_array]: for transform in [to_numpy, to_dask_bag]:
np.random.seed(0) np.random.seed(0)
fit_data = transform(fit_data) fit_data = transform(fit_data)
projected_data = ubm.transform(d for d in fit_data) projected_data = ubm.transform(fit_data)
projected_data = transform(projected_data)
m = IVectorMachine(ubm=ubm, dim_t=2, max_iterations=2) m = IVectorMachine(ubm=ubm, dim_t=2, max_iterations=2)
m.fit(d for d in projected_data) m.fit(projected_data)
result = m.transform(d for d in test_data) result = m.transform(ubm.transform(test_data))
np.testing.assert_almost_equal( np.testing.assert_almost_equal(
np.array(result), reference_result, decimal=5 np.array(result), reference_result, decimal=5
) )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment