diff --git a/bob/learn/em/data/gmm_MAP.hdf5 b/bob/learn/em/data/gmm_MAP.hdf5 index 8106590f1721a6c6084c63bc9637b52a5b39b24e..8d3574922e7d20ca228ef146023433466dd1e2c2 100644 Binary files a/bob/learn/em/data/gmm_MAP.hdf5 and b/bob/learn/em/data/gmm_MAP.hdf5 differ diff --git a/bob/learn/em/data/gmm_ML.hdf5 b/bob/learn/em/data/gmm_ML.hdf5 index 5e667e2498f69cf59303fc30c675c9f6f7b1fcc3..4a5bd139a9a8e681d4da0d0fa4b39faac98f3972 100644 Binary files a/bob/learn/em/data/gmm_ML.hdf5 and b/bob/learn/em/data/gmm_ML.hdf5 differ diff --git a/bob/learn/em/data/gmm_ML_fitted.hdf5 b/bob/learn/em/data/gmm_ML_fitted.hdf5 new file mode 100644 index 0000000000000000000000000000000000000000..2e38a2b1ead3822d9029977eeb2894efa029edf0 Binary files /dev/null and b/bob/learn/em/data/gmm_ML_fitted.hdf5 differ diff --git a/bob/learn/em/data/gmm_ML_legacy.hdf5 b/bob/learn/em/data/gmm_ML_legacy.hdf5 new file mode 100644 index 0000000000000000000000000000000000000000..74269fe3d824aa877e59609097828ee922d5dbc4 Binary files /dev/null and b/bob/learn/em/data/gmm_ML_legacy.hdf5 differ diff --git a/bob/learn/em/data/stats.hdf5 b/bob/learn/em/data/stats.hdf5 index c4a13700ec20079fdaacbd3841e8289910e9dd82..b125212ffea531c0e29450f96e1ebe116d38e5b4 100644 Binary files a/bob/learn/em/data/stats.hdf5 and b/bob/learn/em/data/stats.hdf5 differ diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py index 0583f1b6fbc89dcfc042edbbb3de60db66d80038..19680f7568c68e64624bef468726184c48135ff8 100644 --- a/bob/learn/em/gmm.py +++ b/bob/learn/em/gmm.py @@ -598,7 +598,7 @@ class GMMMachine(BaseEstimator): return self.means.shape @classmethod - def from_hdf5(cls, hdf5, ubm=None): + def from_hdf5(cls, hdf5: Union[str, HDF5File], ubm: "GMMMachine" = None): """Creates a new GMMMachine object from an `HDF5File` object.""" if isinstance(hdf5, str): hdf5 = HDF5File(hdf5, "r") diff --git a/bob/learn/em/test/test_gmm.py b/bob/learn/em/test/test_gmm.py index 81793ef8bbd76a30c75e3074a784c7589781abd9..cb9ebc28cac70d69b71d6ebb2c5dfaaa65bcad21 100644 --- a/bob/learn/em/test/test_gmm.py +++ b/bob/learn/em/test/test_gmm.py @@ -294,9 +294,23 @@ def test_GMMMachine(): ) +def test_GMMMachine_legacy_loading(): + """Tests that old GMMMachine checkpoints are loaded correctly.""" + reference_file = resource_filename("bob.learn.em", "data/gmm_ML.hdf5") + legacy_gmm_file = resource_filename( + "bob.learn.em", "data/gmm_ML_legacy.hdf5" + ) + gmm = GMMMachine.from_hdf5(legacy_gmm_file) + assert isinstance(gmm, GMMMachine) + assert isinstance(gmm.n_gaussians, np.int64), type(gmm.n_gaussians) + assert isinstance(gmm.weights, np.ndarray), type(gmm.weights) + reference = GMMMachine.from_hdf5(reference_file) + np.testing.assert_allclose(gmm.variances, reference.variances) + assert gmm.is_similar_to(reference) + + def test_GMMMachine_stats(): """Tests a GMMMachine (statistics)""" - arrayset = load_array( resource_filename("bob.learn.em", "data/faithful.torch3_f64.hdf5") ) @@ -802,7 +816,9 @@ def test_gmm_ML_1(): resource_filename("bob.learn.em", "data/faithful.torch3_f64.hdf5") ) gmm_ref = GMMMachine.from_hdf5( - HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "r") + HDF5File( + resource_filename("bob.learn.em", "data/gmm_ML_fitted.hdf5"), "r" + ) ) for transform in (to_numpy, to_dask_array): @@ -823,8 +839,6 @@ def test_gmm_ML_1(): gmm.update_means = True gmm.update_variances = True gmm.update_weights = True - # Generate reference - # gmm.save(HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "w")) gmm = gmm.fit(ar) @@ -911,16 +925,6 @@ def test_gmm_MAP_1(): gmmprior = GMMMachine.from_hdf5( HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "r") ) - gmm = GMMMachine.from_hdf5( - HDF5File(resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "r"), - ubm=gmmprior, - ) - gmm.update_means = True - gmm.update_variances = False - gmm.update_weights = False - - # Generate reference - # gmm.save(HDF5File(resource_filename("bob.learn.em", "data/gmm_MAP.hdf5"), "w")) gmm_ref = GMMMachine.from_hdf5( HDF5File(resource_filename("bob.learn.em", "data/gmm_MAP.hdf5"), "r") @@ -928,13 +932,21 @@ def test_gmm_MAP_1(): for transform in (to_numpy, to_dask_array): ar = transform(ar) - gmm = gmm.fit(ar) - - np.testing.assert_almost_equal(gmm.means, gmm_ref.means, decimal=3) + gmm = GMMMachine.from_hdf5( + HDF5File( + resource_filename("bob.learn.em", "data/gmm_ML.hdf5"), "r" + ), + ubm=gmmprior, + ) + gmm.update_means = True + gmm.update_variances = False + gmm.update_weights = False + gmm.fit(ar) + np.testing.assert_almost_equal(gmm.means, gmm_ref.means, decimal=7) np.testing.assert_almost_equal( - gmm.variances, gmm_ref.variances, decimal=3 + gmm.variances, gmm_ref.variances, decimal=7 ) - np.testing.assert_almost_equal(gmm.weights, gmm_ref.weights, decimal=3) + np.testing.assert_almost_equal(gmm.weights, gmm_ref.weights, decimal=7) def test_gmm_MAP_2():