diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py index 762939981b58a90379ceac1d7b0e4073ba8d3f5d..38908cc7feee4ab3fa9d6f28e853534cd0c59357 100644 --- a/bob/learn/em/gmm.py +++ b/bob/learn/em/gmm.py @@ -672,12 +672,10 @@ class GMMMachine(BaseEstimator): def __eq__(self, other): return ( - np.array_equal(self.means, other.means) - and np.array_equal(self.variances, other.variances) - and np.array_equal( - self.variance_thresholds, other.variance_thresholds - ) - and np.array_equal(self.weights, other.weights) + np.allclose(self.means, other.means) + and np.allclose(self.variances, other.variances) + and np.allclose(self.variance_thresholds, other.variance_thresholds) + and np.allclose(self.weights, other.weights) ) def is_similar_to(self, other, rtol=1e-5, atol=1e-8): diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py index 6e4b5d0aea309140a081fbf91a2b92b9bc732935..a8962637b7ff113f6df29abd83f4f82cb98c6c72 100644 --- a/bob/learn/em/test/test_gmm.py +++ b/bob/learn/em/test/test_gmm.py @@ -14,7 +14,6 @@ import tempfile from copy import deepcopy -import dask.array as da import numpy as np from dask.distributed import Client @@ -44,6 +43,33 @@ def multiprocess_dask_client(): client.close() +def loadGMM(): + gmm = GMMMachine(n_gaussians=2) + + gmm.weights = load_array( + resource_filename("bob.learn.em", "data/gmm.init_weights.hdf5") + ) + gmm.means = load_array( + resource_filename("bob.learn.em", "data/gmm.init_means.hdf5") + ) + gmm.variances = load_array( + resource_filename("bob.learn.em", "data/gmm.init_variances.hdf5") + ) + + return gmm + + +def assert_gmm_equal(gmm1, gmm2): + """Asserts that two GMMs are equal""" + np.testing.assert_almost_equal(gmm1.weights, gmm2.weights) + np.testing.assert_almost_equal(gmm1.means, gmm2.means) + np.testing.assert_almost_equal(gmm1.variances, gmm2.variances) + np.testing.assert_almost_equal( + gmm1.variance_thresholds, gmm2.variance_thresholds + ) + assert gmm1 == gmm2 + + def test_GMMStats(): # Test a GMMStats # Initializes a GMMStats @@ -197,20 +223,16 @@ def test_GMMMachine(): gmm6.variances = variances gmm6.variance_thresholds = varianceThresholds2 - assert gmm == gmm2 + assert_gmm_equal(gmm, gmm2) assert (gmm != gmm2) is False assert gmm.is_similar_to(gmm2) assert gmm != gmm3 - assert (gmm == gmm3) is False assert gmm.is_similar_to(gmm3) is False assert gmm != gmm4 - assert (gmm == gmm4) is False assert gmm.is_similar_to(gmm4) is False assert gmm != gmm5 - assert (gmm == gmm5) is False assert gmm.is_similar_to(gmm5) is False assert gmm != gmm6 - assert (gmm == gmm6) is False assert gmm.is_similar_to(gmm6) is False # Saving and loading @@ -225,7 +247,7 @@ def test_GMMMachine(): assert type(gmm1.update_weights) is np.bool_ assert type(gmm1.trainer) is str assert gmm1.ubm is None - assert gmm == gmm1 + assert_gmm_equal(gmm, gmm1) # Using load gmm1 = GMMMachine(n_gaussians=gmm.n_gaussians) gmm1.load(HDF5File(filename, "r")) @@ -235,13 +257,13 @@ def test_GMMMachine(): assert type(gmm1.update_weights) is np.bool_ assert type(gmm1.trainer) is str assert gmm1.ubm is None - assert gmm == gmm1 + assert_gmm_equal(gmm, gmm1) with tempfile.NamedTemporaryFile(suffix=".hdf5") as f: filename = f.name gmm.save(filename) gmm1 = GMMMachine.from_hdf5(filename) - assert gmm == gmm1 + assert_gmm_equal(gmm, gmm1) # Weights n_gaussians = 5 @@ -465,13 +487,15 @@ def test_gmm_kmeans_plusplus_init(): data = np.array( [[1.5, 1], [1, 1.5], [-1, 0.5], [-1.5, 0], [2, 2], [2.5, 2.5]] ) - machine = machine.fit(data) - expected_means = np.array([[2.25, 2.25], [-1.25, 0.25], [1.25, 1.25]]) - expected_variances = np.array( - [[1 / 16, 1 / 16], [1 / 16, 1 / 16], [1 / 16, 1 / 16]] - ) - np.testing.assert_almost_equal(machine.means, expected_means, decimal=3) - np.testing.assert_almost_equal(machine.variances, expected_variances) + for transform in (to_numpy, to_dask_array): + data = transform(data) + machine = machine.fit(data) + expected_means = np.array([[2.25, 2.25], [-1.25, 0.25], [1.25, 1.25]]) + expected_variances = np.array( + [[1 / 16, 1 / 16], [1 / 16, 1 / 16], [1 / 16, 1 / 16]] + ) + np.testing.assert_almost_equal(machine.means, expected_means, decimal=3) + np.testing.assert_almost_equal(machine.variances, expected_variances) def test_gmm_kmeans_parallel_init(): @@ -509,16 +533,18 @@ def test_likelihood(): machine = GMMMachine(n_gaussians) machine.means = np.repeat([[0], [1], [-1]], 3, 1) machine.variances = np.ones_like(machine.means) - log_likelihood = machine.log_likelihood(data) - expected_ll = np.array( - [ - -3.6519900964986527, - -3.83151883210222, - -3.83151883210222, - -5.344374066745753, - ] - ) - np.testing.assert_almost_equal(log_likelihood, expected_ll) + for transform in (to_numpy, to_dask_array): + data = transform(data) + log_likelihood = machine.log_likelihood(data) + expected_ll = np.array( + [ + -3.6519900964986527, + -3.83151883210222, + -3.83151883210222, + -5.344374066745753, + ] + ) + np.testing.assert_almost_equal(log_likelihood, expected_ll) def test_likelihood_variance(): @@ -533,16 +559,18 @@ def test_likelihood_variance(): [1, 1, 1], ] ) - log_likelihood = machine.log_likelihood(data) - expected_ll = np.array( - [ - -2.202846959440514, - -3.8699524542323793, - -4.229029034375473, - -6.940892214952679, - ] - ) - np.testing.assert_almost_equal(log_likelihood, expected_ll) + for transform in (to_numpy, to_dask_array): + data = transform(data) + log_likelihood = machine.log_likelihood(data) + expected_ll = np.array( + [ + -2.202846959440514, + -3.8699524542323793, + -4.229029034375473, + -6.940892214952679, + ] + ) + np.testing.assert_almost_equal(log_likelihood, expected_ll) def test_likelihood_weight(): @@ -552,16 +580,18 @@ def test_likelihood_weight(): machine.means = np.repeat([[0], [1], [-1]], 3, 1) machine.variances = np.ones_like(machine.means) machine.weights = [0.6, 0.1, 0.3] - log_likelihood = machine.log_likelihood(data) - expected_ll = np.array( - [ - -4.206596356117164, - -3.492325679996329, - -3.634745457950943, - -6.49485678536014, - ] - ) - np.testing.assert_almost_equal(log_likelihood, expected_ll) + for transform in (to_numpy, to_dask_array): + data = transform(data) + log_likelihood = machine.log_likelihood(data) + expected_ll = np.array( + [ + -4.206596356117164, + -3.492325679996329, + -3.634745457950943, + -6.49485678536014, + ] + ) + np.testing.assert_almost_equal(log_likelihood, expected_ll) def test_GMMMachine_object(): @@ -688,27 +718,31 @@ def test_ml_transformer(): machine.means = np.array([[2, 2, 2], [8, 8, 8]]) machine.variances = np.ones_like(machine.means) - machine = machine.fit(data) - - expected_means = np.array([[1.5, 1.5, 2.0], [7.0, 8.0, 8.0]]) - np.testing.assert_almost_equal(machine.means, expected_means) - expected_weights = np.array([2 / 5, 3 / 5]) - np.testing.assert_almost_equal(machine.weights, expected_weights) - eps = np.finfo(float).eps - expected_variances = np.array([[1 / 4, 1 / 4, eps], [eps, 2 / 3, 2 / 3]]) - np.testing.assert_almost_equal(machine.variances, expected_variances) + for transform in (to_numpy, to_dask_array): + data = transform(data) + machine = machine.fit(data) + + expected_means = np.array([[1.5, 1.5, 2.0], [7.0, 8.0, 8.0]]) + np.testing.assert_almost_equal(machine.means, expected_means) + expected_weights = np.array([2 / 5, 3 / 5]) + np.testing.assert_almost_equal(machine.weights, expected_weights) + eps = np.finfo(float).eps + expected_variances = np.array( + [[1 / 4, 1 / 4, eps], [eps, 2 / 3, 2 / 3]] + ) + np.testing.assert_almost_equal(machine.variances, expected_variances) - stats = machine.transform(test_data) + stats = machine.transform(test_data) - expected_stats = GMMStats(n_gaussians, n_features) - expected_stats.init_fields( - log_likelihood=-6755399441055685.0, - t=test_data.shape[0], - n=np.array([2, 2], dtype=float), - sum_px=np.array([[2, 2, 3], [16, 17, 17]], dtype=float), - sum_pxx=np.array([[2, 2, 5], [128, 145, 145]], dtype=float), - ) - assert stats.is_similar_to(expected_stats) + expected_stats = GMMStats(n_gaussians, n_features) + expected_stats.init_fields( + log_likelihood=-6755399441055685.0, + t=test_data.shape[0], + n=np.array([2, 2], dtype=float), + sum_px=np.array([[2, 2, 3], [16, 17, 17]], dtype=float), + sum_pxx=np.array([[2, 2, 5], [128, 145, 145]], dtype=float), + ) + assert stats.is_similar_to(expected_stats) def test_map_transformer(): @@ -732,79 +766,69 @@ def test_map_transformer(): update_weights=True, ) - machine = machine.fit(post_data) + for transform in (to_numpy, to_dask_array): + post_data = transform(post_data) + machine = machine.fit(post_data) - expected_means = np.array( - [[1.83333333, 1.83333333, 2.0], [7.57142857, 8, 8]] - ) - np.testing.assert_almost_equal(machine.means, expected_means) - eps = np.finfo(float).eps - expected_vars = np.array([[eps, eps, eps], [eps, eps, eps]]) - np.testing.assert_almost_equal(machine.variances, expected_vars) - expected_weights = np.array([0.46226415, 0.53773585]) - np.testing.assert_almost_equal(machine.weights, expected_weights) - - stats = machine.transform(test_data) - - expected_stats = GMMStats(n_gaussians, n_features) - expected_stats.init_fields( - log_likelihood=-1.3837590691807108e16, - t=test_data.shape[0], - n=np.array([2, 2], dtype=float), - sum_px=np.array([[2, 2, 3], [16, 17, 17]], dtype=float), - sum_pxx=np.array([[2, 2, 5], [128, 145, 145]], dtype=float), - ) - assert stats.is_similar_to(expected_stats) + expected_means = np.array( + [[1.83333333, 1.83333333, 2.0], [7.57142857, 8, 8]] + ) + np.testing.assert_almost_equal(machine.means, expected_means) + eps = np.finfo(float).eps + expected_vars = np.array([[eps, eps, eps], [eps, eps, eps]]) + np.testing.assert_almost_equal(machine.variances, expected_vars) + expected_weights = np.array([0.46226415, 0.53773585]) + np.testing.assert_almost_equal(machine.weights, expected_weights) + + stats = machine.transform(test_data) + + expected_stats = GMMStats(n_gaussians, n_features) + expected_stats.init_fields( + log_likelihood=-1.3837590691807108e16, + t=test_data.shape[0], + n=np.array([2, 2], dtype=float), + sum_px=np.array([[2, 2, 3], [16, 17, 17]], dtype=float), + sum_pxx=np.array([[2, 2, 5], [128, 145, 145]], dtype=float), + ) + assert stats.is_similar_to(expected_stats) # Tests from `test_em.py` -def loadGMM(): - gmm = GMMMachine(n_gaussians=2) - - gmm.weights = load_array( - resource_filename("bob.learn.em", "data/gmm.init_weights.hdf5") - ) - gmm.means = load_array( - resource_filename("bob.learn.em", "data/gmm.init_means.hdf5") - ) - gmm.variances = load_array( - resource_filename("bob.learn.em", "data/gmm.init_variances.hdf5") - ) - - return gmm - - def test_gmm_ML_1(): """Trains a GMMMachine with ML_GMMTrainer""" ar = load_array( resource_filename("bob.learn.em", "data/faithful.torch3_f64.hdf5") ) - gmm = loadGMM() + gmm_ref = GMMMachine.from_hdf5( + HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "r") + ) - # test rng handling - gmm.convergence_threshold = 0.001 - gmm.update_means = True - gmm.update_variances = True - gmm.update_weights = True - gmm.random_state = np.random.RandomState(seed=12345) - gmm = gmm.fit(ar) + for transform in (to_numpy, to_dask_array): + ar = transform(ar) - gmm = loadGMM() - gmm.convergence_threshold = 0.001 - gmm.update_means = True - gmm.update_variances = True - gmm.update_weights = True - gmm = gmm.fit(ar) + gmm = loadGMM() - # Generate reference - # gmm.save(HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "w")) + # test rng handling + gmm.convergence_threshold = 0.001 + gmm.update_means = True + gmm.update_variances = True + gmm.update_weights = True + gmm.random_state = np.random.RandomState(seed=12345) + gmm = gmm.fit(ar) - gmm_ref = GMMMachine.from_hdf5( - HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "r") - ) - assert gmm == gmm_ref + gmm = loadGMM() + gmm.convergence_threshold = 0.001 + gmm.update_means = True + gmm.update_variances = True + gmm.update_weights = True + # Generate reference + # gmm.save(HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "w")) + + gmm = gmm.fit(ar) + + assert_gmm_equal(gmm, gmm_ref) def test_gmm_ML_2(): @@ -813,33 +837,6 @@ def test_gmm_ML_2(): resource_filename("bob.learn.em", "data/dataNormalized.hdf5") ) - # Initialize GMMMachine - gmm = GMMMachine(n_gaussians=5) - gmm.means = load_array( - resource_filename("bob.learn.em", "data/meansAfterKMeans.hdf5") - ).astype("float64") - gmm.variances = load_array( - resource_filename("bob.learn.em", "data/variancesAfterKMeans.hdf5") - ).astype("float64") - gmm.weights = np.exp( - load_array( - resource_filename("bob.learn.em", "data/weightsAfterKMeans.hdf5") - ).astype("float64") - ) - - threshold = 0.001 - gmm.variance_thresholds = threshold - - # Initialize ML Trainer - gmm.mean_var_update_threshold = 0.001 - gmm.max_fitting_steps = 25 - gmm.convergence_threshold = 0.000001 - gmm.update_means = True - gmm.update_variances = True - gmm.update_weights = True - - # Run ML - gmm = gmm.fit(ar) # Test results # Load torch3vision reference meansML_ref = load_array( @@ -852,10 +849,42 @@ def test_gmm_ML_2(): resource_filename("bob.learn.em", "data/weightsAfterML.hdf5") ) - # Compare to current results - np.testing.assert_allclose(gmm.means, meansML_ref, atol=3e-3) - np.testing.assert_allclose(gmm.variances, variancesML_ref, atol=3e-3) - np.testing.assert_allclose(gmm.weights, weightsML_ref, atol=1e-4) + for transform in (to_numpy, to_dask_array): + ar = transform(ar) + # Initialize GMMMachine + gmm = GMMMachine(n_gaussians=5) + gmm.means = load_array( + resource_filename("bob.learn.em", "data/meansAfterKMeans.hdf5") + ).astype("float64") + gmm.variances = load_array( + resource_filename("bob.learn.em", "data/variancesAfterKMeans.hdf5") + ).astype("float64") + gmm.weights = np.exp( + load_array( + resource_filename( + "bob.learn.em", "data/weightsAfterKMeans.hdf5" + ) + ).astype("float64") + ) + + threshold = 0.001 + gmm.variance_thresholds = threshold + + # Initialize ML Trainer + gmm.mean_var_update_threshold = 0.001 + gmm.max_fitting_steps = 25 + gmm.convergence_threshold = 0.000001 + gmm.update_means = True + gmm.update_variances = True + gmm.update_weights = True + + # Run ML + gmm = gmm.fit(ar) + + # Compare to current results + np.testing.assert_allclose(gmm.means, meansML_ref, atol=3e-3) + np.testing.assert_allclose(gmm.variances, variancesML_ref, atol=3e-3) + np.testing.assert_allclose(gmm.weights, weightsML_ref, atol=1e-4) def test_gmm_MAP_1(): @@ -890,8 +919,6 @@ def test_gmm_MAP_1(): gmm.update_variances = False gmm.update_weights = False - gmm = gmm.fit(ar) - # Generate reference # gmm.save(HDF5File(resource_filename("bob.learn.em", "data/gmm_MAP.hdf5"), "w")) @@ -899,9 +926,15 @@ def test_gmm_MAP_1(): HDF5File(resource_filename("bob.learn.em", "data/gmm_MAP.hdf5"), "r") ) - np.testing.assert_almost_equal(gmm.means, gmm_ref.means, decimal=3) - np.testing.assert_almost_equal(gmm.variances, gmm_ref.variances, decimal=3) - np.testing.assert_almost_equal(gmm.weights, gmm_ref.weights, decimal=3) + for transform in (to_numpy, to_dask_array): + ar = transform(ar) + gmm = gmm.fit(ar) + + np.testing.assert_almost_equal(gmm.means, gmm_ref.means, decimal=3) + np.testing.assert_almost_equal( + gmm.variances, gmm_ref.variances, decimal=3 + ) + np.testing.assert_almost_equal(gmm.weights, gmm_ref.weights, decimal=3) def test_gmm_MAP_2(): @@ -934,14 +967,16 @@ def test_gmm_MAP_2(): gmm_adapted.variances = variances gmm_adapted.weights = weights - gmm_adapted = gmm_adapted.fit(data) - new_means = load_array( resource_filename("bob.learn.em", "data/new_adapted_mean.hdf5") ) - # Compare to matlab reference - np.testing.assert_allclose(new_means.T, gmm_adapted.means, rtol=1e-4) + for transform in (to_numpy, to_dask_array): + data = transform(data) + gmm_adapted = gmm_adapted.fit(data) + + # Compare to matlab reference + np.testing.assert_allclose(new_means.T, gmm_adapted.means, rtol=1e-4) def test_gmm_MAP_3(): @@ -981,9 +1016,6 @@ def test_gmm_MAP_3(): ) gmm.variance_thresholds = threshold - # Train - gmm = gmm.fit(ar) - # Test results # Load torch3vision reference meansMAP_ref = load_array( @@ -996,13 +1028,18 @@ def test_gmm_MAP_3(): resource_filename("bob.learn.em", "data/weightsAfterMAP.hdf5") ) - # Compare to current results - # Gaps are quite large. This might be explained by the fact that there is no - # adaptation of a given Gaussian in torch3 when the corresponding responsibilities - # are below the responsibilities threshold - np.testing.assert_allclose(gmm.means, meansMAP_ref, atol=2e-1) - np.testing.assert_allclose(gmm.variances, variancesMAP_ref, atol=1e-4) - np.testing.assert_allclose(gmm.weights, weightsMAP_ref, atol=1e-4) + for transform in (to_numpy, to_dask_array): + ar = transform(ar) + # Train + gmm = gmm.fit(ar) + + # Compare to current results + # Gaps are quite large. This might be explained by the fact that there is no + # adaptation of a given Gaussian in torch3 when the corresponding responsibilities + # are below the responsibilities threshold + np.testing.assert_allclose(gmm.means, meansMAP_ref, atol=2e-1) + np.testing.assert_allclose(gmm.variances, variancesMAP_ref, atol=1e-4) + np.testing.assert_allclose(gmm.weights, weightsMAP_ref, atol=1e-4) def test_gmm_test(): @@ -1028,126 +1065,10 @@ def test_gmm_test(): # Test against the model score_mean_ref = -1.50379e06 - score = gmm.log_likelihood(ar).sum() - score /= len(ar) - - # Compare current results to torch3vision - assert abs(score - score_mean_ref) / score_mean_ref < 1e-4 - - -def test_gmm_ML_dask(): - # Trains a GMMMachine with dask array data; compares to a reference - - ar = da.array( - load_array( - resource_filename("bob.learn.em", "data/dataNormalized.hdf5") - ) - ) - - # Initialize GMMMachine - gmm = GMMMachine(n_gaussians=5) - gmm.means = load_array( - resource_filename("bob.learn.em", "data/meansAfterKMeans.hdf5") - ).astype("float64") - gmm.variances = load_array( - resource_filename("bob.learn.em", "data/variancesAfterKMeans.hdf5") - ).astype("float64") - gmm.weights = np.exp( - load_array( - resource_filename("bob.learn.em", "data/weightsAfterKMeans.hdf5") - ).astype("float64") - ) - - threshold = 0.001 - gmm.variance_thresholds = threshold - - # Initialize ML Trainer - gmm.mean_var_update_threshold = 0.001 - gmm.max_fitting_steps = 25 - gmm.convergence_threshold = 0.00001 - gmm.update_means = True - gmm.update_variances = True - gmm.update_weights = True - - # Run ML - gmm.fit(ar) - - # Test results - # Load torch3vision reference - meansML_ref = load_array( - resource_filename("bob.learn.em", "data/meansAfterML.hdf5") - ) - variancesML_ref = load_array( - resource_filename("bob.learn.em", "data/variancesAfterML.hdf5") - ) - weightsML_ref = load_array( - resource_filename("bob.learn.em", "data/weightsAfterML.hdf5") - ) - - # Compare to current results - np.testing.assert_allclose(gmm.means, meansML_ref, atol=3e-3) - np.testing.assert_allclose(gmm.variances, variancesML_ref, atol=3e-3) - np.testing.assert_allclose(gmm.weights, weightsML_ref, atol=1e-4) - - -def test_gmm_MAP_dask(): - # Test a GMMMachine for MAP with a dask array as data. - ar = da.array( - load_array(resource_filename("bob.learn.em", "data/dataforMAP.hdf5")) - ) - - # Initialize GMMMachine - n_gaussians = 5 - prior_gmm = GMMMachine(n_gaussians) - prior_gmm.means = load_array( - resource_filename("bob.learn.em", "data/meansAfterML.hdf5") - ) - prior_gmm.variances = load_array( - resource_filename("bob.learn.em", "data/variancesAfterML.hdf5") - ) - prior_gmm.weights = load_array( - resource_filename("bob.learn.em", "data/weightsAfterML.hdf5") - ) - - threshold = 0.001 - prior_gmm.variance_thresholds = threshold - - # Initialize MAP Trainer - prior = 0.001 - accuracy = 0.00001 - gmm = GMMMachine( - n_gaussians, - trainer="map", - ubm=prior_gmm, - convergence_threshold=prior, - max_fitting_steps=1, - update_means=True, - update_variances=False, - update_weights=False, - mean_var_update_threshold=accuracy, - map_relevance_factor=None, - ) - gmm.variance_thresholds = threshold - - # Train - gmm = gmm.fit(ar) - - # Test results - # Load torch3vision reference - meansMAP_ref = load_array( - resource_filename("bob.learn.em", "data/meansAfterMAP.hdf5") - ) - variancesMAP_ref = load_array( - resource_filename("bob.learn.em", "data/variancesAfterMAP.hdf5") - ) - weightsMAP_ref = load_array( - resource_filename("bob.learn.em", "data/weightsAfterMAP.hdf5") - ) + for transform in (to_numpy, to_dask_array): + ar = transform(ar) + score = gmm.log_likelihood(ar).sum() + score /= len(ar) - # Compare to current results - # Gaps are quite large. This might be explained by the fact that there is no - # adaptation of a given Gaussian in torch3 when the corresponding responsibilities - # are below the responsibilities threshold - np.testing.assert_allclose(gmm.means, meansMAP_ref, atol=2e-1) - np.testing.assert_allclose(gmm.variances, variancesMAP_ref, atol=1e-4) - np.testing.assert_allclose(gmm.weights, weightsMAP_ref, atol=1e-4) + # Compare current results to torch3vision + assert abs(score - score_mean_ref) / score_mean_ref < 1e-4 diff --git a/bob/learn/em/test/test_kmeans.py b/bob/learn/em/test/test_kmeans.py index 8edc205ab9ecf2f6f15a8c9e4213f57cb219f804..613a562bf864b4d66b31ae5a410fc6ad6298e164 100644 --- a/bob/learn/em/test/test_kmeans.py +++ b/bob/learn/em/test/test_kmeans.py @@ -33,7 +33,7 @@ def to_dask_array(*args): for x in args: x = np.asarray(x) chunks = list(x.shape) - chunks[0] //= 2 + chunks[0] = int(np.ceil(chunks[0] / 2)) result.append(da.from_array(x, chunks=chunks)) if len(result) == 1: return result[0]