From 0f4c5437f35594a6b83da91edc310c601035b05a Mon Sep 17 00:00:00 2001 From: Yannick DAYER <yannick.dayer@idiap.ch> Date: Fri, 30 Sep 2022 16:24:11 +0200 Subject: [PATCH] Fix I-Vector crashing at mstep with many chunks. --- bob/learn/em/ivector.py | 44 +++++++++++++++---------------- bob/learn/em/test/test_ivector.py | 7 +++-- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/bob/learn/em/ivector.py b/bob/learn/em/ivector.py index cf9c8bf..927e51e 100644 --- a/bob/learn/em/ivector.py +++ b/bob/learn/em/ivector.py @@ -3,7 +3,6 @@ # @date: Fri 06 May 2022 14:18:25 UTC+02 import copy -import functools import logging import operator @@ -179,13 +178,9 @@ def e_step(machine: "IVectorMachine", data: List[GMMStats]) -> IVectorStats: return stats -def m_step( - machine: "IVectorMachine", stats: List[IVectorStats] -) -> "IVectorMachine": +def m_step(machine: "IVectorMachine", stats: IVectorStats) -> "IVectorMachine": """Updates the Machine with the maximization step of the e-m algorithm.""" - # Merge all the stats - stats = functools.reduce(operator.iadd, stats) - + logger.debug("Computing new machine parameters.") A = stats.nij_sigma_wij2.transpose((0, 2, 1)) B = stats.fnorm_sigma_wij.transpose((0, 2, 1)) @@ -274,16 +269,9 @@ class IVectorMachine(BaseEstimator): ) -> "IVectorMachine": """Trains the IVectorMachine. - Repeats the e-m steps until the convergence criterion is met or - ``max_iterations`` is reached. + Repeats the e-m steps until ``max_iterations`` is reached. """ - # if not isinstance(X[0], GMMStats): - # logger.info("Received non-GMM data. Will train it on the UBM.") - # if self.ubm._means is None: # Train a GMMMachine if not set - # logger.info("UBM not trained. Training it inside IVectorMachine.") - # self.ubm.fit(X) - # X = self.ubm.transform(X) # Transform to GMMStats chunky = False if isinstance(X, dask.bag.Bag): chunky = True @@ -308,17 +296,27 @@ class IVectorMachine(BaseEstimator): ) for xx in X ] - logger.debug(f"Computing step {step}") - new_machine = dask.compute(dask.delayed(m_step)(self, stats))[0] + + # Workaround to prevent memory issues at compute with too many chunks. + # This adds pairs of stats together instead of sending all the stats to + # one worker. + while (l := len(stats)) > 1: + last = stats[-1] + stats = [ + dask.delayed(operator.add)(stats[i], stats[l // 2 + i]) + for i in range(l // 2) + ] + if l % 2 != 0: + stats.append(last) + + stats_sum = stats[0] + new_machine = dask.compute( + dask.delayed(m_step)(self, stats_sum) + )[0] for attr in ["T", "sigma"]: setattr(self, attr, getattr(new_machine, attr)) else: - stats = [ - e_step( - machine=self, - data=X, - ) - ] + stats = e_step(machine=self, data=X) _ = m_step(self, stats) logger.info( f"IVector step {step+1:{len(str(self.max_iterations))}d}/{self.max_iterations}." diff --git a/bob/learn/em/test/test_ivector.py b/bob/learn/em/test/test_ivector.py index 13f27a2..e1a30f7 100644 --- a/bob/learn/em/test/test_ivector.py +++ b/bob/learn/em/test/test_ivector.py @@ -130,11 +130,10 @@ def test_ivector_machine_training(): test_data.sum_px = np.array([[8, 0, 4], [6, 6, 6]]) test_data.sum_pxx = np.array([[10, 20, 30], [60, 70, 80]]) projected = machine.project(test_data) - print([f"{p:.8f}" for p in projected]) proj_reference = np.array([0.94234370, -0.61558459]) - np.testing.assert_almost_equal(projected, proj_reference, decimal=7) + np.testing.assert_almost_equal(projected, proj_reference, decimal=4) def _load_references_from_file(filename): @@ -202,7 +201,7 @@ def test_trainer_nosigma(): ) # M-Step - m_step(m, [stats]) + m_step(m, stats) np.testing.assert_almost_equal(references[it]["T"], m.T, decimal=5) np.testing.assert_equal( init_sigma, m.sigma @@ -260,7 +259,7 @@ def test_trainer_update_sigma(): ) # M-Step - m_step(m, [stats]) + m_step(m, stats) np.testing.assert_almost_equal(references[it]["T"], m.T, decimal=5) np.testing.assert_almost_equal( references[it]["sigma"], m.sigma, decimal=5 -- GitLab