Skip to content
Snippets Groups Projects

Fixes to kmeans and gmm

Merged Yannick DAYER requested to merge params into master
3 files
+ 38
25
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -202,6 +202,7 @@ def test_GMMMachine():
@@ -202,6 +202,7 @@ def test_GMMMachine():
with tempfile.NamedTemporaryFile(suffix=".hdf5") as f:
with tempfile.NamedTemporaryFile(suffix=".hdf5") as f:
filename = f.name
filename = f.name
gmm.save(HDF5File(filename, "w"))
gmm.save(HDF5File(filename, "w"))
 
# Using from_hdf5
gmm1 = GMMMachine.from_hdf5(HDF5File(filename, "r"))
gmm1 = GMMMachine.from_hdf5(HDF5File(filename, "r"))
assert type(gmm1.n_gaussians) is np.int64
assert type(gmm1.n_gaussians) is np.int64
assert type(gmm1.update_means) is np.bool_
assert type(gmm1.update_means) is np.bool_
@@ -210,6 +211,17 @@ def test_GMMMachine():
@@ -210,6 +211,17 @@ def test_GMMMachine():
assert type(gmm1.trainer) is str
assert type(gmm1.trainer) is str
assert gmm1.ubm is None
assert gmm1.ubm is None
assert gmm == gmm1
assert gmm == gmm1
 
# Using load
 
gmm1 = GMMMachine(n_gaussians=gmm.n_gaussians)
 
gmm1.load(HDF5File(filename, "r"))
 
assert type(gmm1.n_gaussians) is np.int64
 
assert type(gmm1.update_means) is np.bool_
 
assert type(gmm1.update_variances) is np.bool_
 
assert type(gmm1.update_weights) is np.bool_
 
assert type(gmm1.trainer) is str
 
assert gmm1.ubm is None
 
assert gmm == gmm1
 
with tempfile.NamedTemporaryFile(suffix=".hdf5") as f:
with tempfile.NamedTemporaryFile(suffix=".hdf5") as f:
filename = f.name
filename = f.name
gmm.save(filename)
gmm.save(filename)
@@ -923,7 +935,7 @@ def test_gmm_MAP_3():
@@ -923,7 +935,7 @@ def test_gmm_MAP_3():
update_variances=False,
update_variances=False,
update_weights=False,
update_weights=False,
mean_var_update_threshold=accuracy,
mean_var_update_threshold=accuracy,
relevance_factor=None,
map_relevance_factor=None,
)
)
gmm.variance_thresholds = threshold
gmm.variance_thresholds = threshold
@@ -1071,7 +1083,7 @@ def test_gmm_MAP_dask():
@@ -1071,7 +1083,7 @@ def test_gmm_MAP_dask():
update_variances=False,
update_variances=False,
update_weights=False,
update_weights=False,
mean_var_update_threshold=accuracy,
mean_var_update_threshold=accuracy,
relevance_factor=None,
map_relevance_factor=None,
)
)
gmm.variance_thresholds = threshold
gmm.variance_thresholds = threshold
Loading