Commit 0f4c5437 authored by Yannick DAYER

Fix I-Vector crashing at mstep with many chunks.

parent 4a9143c1
1 merge request: !60 Port of I-Vector to python
Pipeline #65241 passed
@@ -3,7 +3,6 @@
 # @date: Fri 06 May 2022 14:18:25 UTC+02

 import copy
-import functools
 import logging
 import operator
@@ -179,13 +178,9 @@ def e_step(machine: "IVectorMachine", data: List[GMMStats]) -> IVectorStats:
     return stats


-def m_step(
-    machine: "IVectorMachine", stats: List[IVectorStats]
-) -> "IVectorMachine":
+def m_step(machine: "IVectorMachine", stats: IVectorStats) -> "IVectorMachine":
     """Updates the Machine with the maximization step of the e-m algorithm."""
-
-    # Merge all the stats
-    stats = functools.reduce(operator.iadd, stats)
+    logger.debug("Computing new machine parameters.")

     A = stats.nij_sigma_wij2.transpose((0, 2, 1))
     B = stats.fnorm_sigma_wij.transpose((0, 2, 1))
@@ -274,16 +269,9 @@ class IVectorMachine(BaseEstimator):
     ) -> "IVectorMachine":
         """Trains the IVectorMachine.

-        Repeats the e-m steps until the convergence criterion is met or
-        ``max_iterations`` is reached.
+        Repeats the e-m steps until ``max_iterations`` is reached.
         """

-        # if not isinstance(X[0], GMMStats):
-        #     logger.info("Received non-GMM data. Will train it on the UBM.")
-        #     if self.ubm._means is None:  # Train a GMMMachine if not set
-        #         logger.info("UBM not trained. Training it inside IVectorMachine.")
-        #         self.ubm.fit(X)
-        #     X = self.ubm.transform(X)  # Transform to GMMStats
         chunky = False
         if isinstance(X, dask.bag.Bag):
             chunky = True
@@ -308,17 +296,27 @@ class IVectorMachine(BaseEstimator):
                     )
                     for xx in X
                 ]
-                logger.debug(f"Computing step {step}")
-                new_machine = dask.compute(dask.delayed(m_step)(self, stats))[0]
+
+                # Workaround to prevent memory issues at compute with too many chunks.
+                # This adds pairs of stats together instead of sending all the stats to
+                # one worker.
+                while (l := len(stats)) > 1:
+                    last = stats[-1]
+                    stats = [
+                        dask.delayed(operator.add)(stats[i], stats[l // 2 + i])
+                        for i in range(l // 2)
+                    ]
+                    if l % 2 != 0:
+                        stats.append(last)
+
+                stats_sum = stats[0]
+                new_machine = dask.compute(
+                    dask.delayed(m_step)(self, stats_sum)
+                )[0]
                 for attr in ["T", "sigma"]:
                     setattr(self, attr, getattr(new_machine, attr))
             else:
-                stats = [
-                    e_step(
-                        machine=self,
-                        data=X,
-                    )
-                ]
+                stats = e_step(machine=self, data=X)
                 _ = m_step(self, stats)
             logger.info(
                 f"IVector step {step+1:{len(str(self.max_iterations))}d}/{self.max_iterations}."
...
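The pairwise merge introduced above sums the per-chunk statistics as a balanced tree of small additions, so the delayed chunks are never all shipped to one worker before m_step runs. Below is a minimal, standalone sketch of the same reduction pattern, assuming only that dask is installed; plain integers stand in for the delayed IVectorStats objects, and the helper name tree_sum and the toy data are illustrative, not part of this commit.

import operator

import dask


def tree_sum(items):
    """Sum delayed items pairwise so no single task gathers every chunk."""
    items = [dask.delayed(x) for x in items]
    while (n := len(items)) > 1:
        last = items[-1]
        # Pair element i with its partner in the second half of the list.
        items = [
            dask.delayed(operator.add)(items[i], items[n // 2 + i])
            for i in range(n // 2)
        ]
        # An odd leftover is carried over untouched to the next round.
        if n % 2 != 0:
            items.append(last)
    return items[0]


total = tree_sum(range(100))
print(dask.compute(total)[0])  # 4950

With k chunks this schedules k - 1 small additions in roughly log2(k) rounds, which matches the commit comment: pairs of stats are added together instead of sending all the stats to one worker at once.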
@@ -130,11 +130,10 @@ def test_ivector_machine_training():
     test_data.sum_px = np.array([[8, 0, 4], [6, 6, 6]])
     test_data.sum_pxx = np.array([[10, 20, 30], [60, 70, 80]])
     projected = machine.project(test_data)

-    print([f"{p:.8f}" for p in projected])
     proj_reference = np.array([0.94234370, -0.61558459])

-    np.testing.assert_almost_equal(projected, proj_reference, decimal=7)
+    np.testing.assert_almost_equal(projected, proj_reference, decimal=4)


 def _load_references_from_file(filename):
@@ -202,7 +201,7 @@ def test_trainer_nosigma():
         )

         # M-Step
-        m_step(m, [stats])
+        m_step(m, stats)
         np.testing.assert_almost_equal(references[it]["T"], m.T, decimal=5)
         np.testing.assert_equal(
             init_sigma, m.sigma
@@ -260,7 +259,7 @@ def test_trainer_update_sigma():
         )

         # M-Step
-        m_step(m, [stats])
+        m_step(m, stats)
         np.testing.assert_almost_equal(references[it]["T"], m.T, decimal=5)
         np.testing.assert_almost_equal(
             references[it]["sigma"], m.sigma, decimal=5
...