From 6b84f7ddbb9524e4d2b7e88ad9d350175355dfb0 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Wed, 16 Mar 2022 11:18:32 +0100
Subject: [PATCH] [gmm] fix gmm training when dask workers are in a different
 process

---
 bob/learn/em/gmm.py           |  8 +++++---
 bob/learn/em/test/test_gmm.py | 26 +++++++++++++++++++-------
 2 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index 436d169..88efe81 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -119,7 +119,7 @@ def m_step(
         relevance_factor=map_relevance_factor,
     )
     average_output = float(statistics.log_likelihood / statistics.t)
-    return average_output
+    return machine, average_output
 
 
 class GMMStats:
@@ -888,9 +888,11 @@ class GMMMachine(BaseEstimator):
                     )
                     for xx in X
                 ]
-                average_output = dask.compute(
+                new_machine, average_output = dask.compute(
                     dask.delayed(m_step_func)(self, stats)
                 )[0]
+                for attr in ["weights", "means", "variances"]:
+                    setattr(self, attr, getattr(new_machine, attr))
             else:
                 stats = [
                     e_step(
@@ -902,7 +904,7 @@
                         log_weights=self.log_weights,
                     )
                 ]
-                average_output = m_step_func(self, stats)
+                _, average_output = m_step_func(self, stats)
 
             logger.debug(f"log likelihood = {average_output}")
             if step > 1:
diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py
index 7eda1c6..3ad780a 100644
--- a/bob/learn/em/test/test_gmm.py
+++ b/bob/learn/em/test/test_gmm.py
@@ -16,11 +16,14 @@ from copy import deepcopy
 
 import dask.array as da
 import numpy as np
+from dask.distributed import Client
 from h5py import File as HDF5File
 from pkg_resources import resource_filename
 
 from bob.learn.em import GMMMachine, GMMStats, KMeansMachine
 
+from .test_kmeans import to_dask_array, to_numpy
+
 
 def load_array(filename):
     with HDF5File(filename, "r") as f:
@@ -464,13 +467,22 @@ def test_gmm_kmeans_parallel_init():
     data = np.array(
         [[1.5, 1], [1, 1.5], [-1, 0.5], [-1.5, 0], [2, 2], [2.5, 2.5]]
     )
-    machine = machine.fit(data)
-    expected_means = np.array([[1.25, 1.25], [-1.25, 0.25], [2.25, 2.25]])
-    expected_variances = np.array(
-        [[1 / 16, 1 / 16], [1 / 16, 1 / 16], [1 / 16, 1 / 16]]
-    )
-    np.testing.assert_almost_equal(machine.means, expected_means, decimal=3)
-    np.testing.assert_almost_equal(machine.variances, expected_variances)
+    with Client().as_current():
+        for transform in (to_numpy, to_dask_array):
+            data = transform(data)
+            machine = machine.fit(data)
+            expected_means = np.array(
+                [[1.25, 1.25], [-1.25, 0.25], [2.25, 2.25]]
+            )
+            expected_variances = np.array(
+                [[1 / 16, 1 / 16], [1 / 16, 1 / 16], [1 / 16, 1 / 16]]
+            )
+            np.testing.assert_almost_equal(
+                machine.means, expected_means, decimal=3
+            )
+            np.testing.assert_almost_equal(
+                machine.variances, expected_variances
+            )
 
 
 def test_likelihood():
--
GitLab
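
Why the patch copies attributes back onto self: when dask workers run in a
separate process, the dask.delayed(m_step_func)(self, stats) task operates on
a pickled copy of the machine, so an in-place update of its weights, means,
and variances never reaches the caller. Returning the machine from m_step and
copying those attributes back propagates the update. The minimal sketch below
reproduces the same failure and fix with a toy class; the Toy class and the
update function are illustrative stand-ins, not part of bob.learn.em.

import dask
from dask.distributed import Client


class Toy:
    """Stand-in for GMMMachine: holds a parameter updated by an "m-step"."""

    def __init__(self):
        self.means = 0.0


def update(obj, delta):
    """Stand-in for m_step: mutates obj in place and returns it with a score."""
    obj.means += delta  # this mutates the worker's copy of obj
    return obj, float(delta)


if __name__ == "__main__":
    # Workers in separate processes, as in the situation the patch fixes.
    with Client(processes=True):
        toy = Toy()

        # The delayed call runs on a worker, so it mutates a pickled copy:
        # after compute(), toy.means is unchanged on the client side.
        new_toy, score = dask.compute(dask.delayed(update)(toy, 1.0))[0]
        print(toy.means)      # 0.0 -- the in-place mutation was lost
        print(new_toy.means)  # 1.0 -- the returned copy carries the update

        # The pattern used in the patch: copy attributes back onto the
        # original object so later iterations see the new parameters.
        for attr in ["means"]:
            setattr(toy, attr, getattr(new_toy, attr))
        print(toy.means)      # 1.0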