Skip to content
Snippets Groups Projects

Fix the initialization of GMMMachine

Merged Yannick DAYER requested to merge fix-init into master
2 unresolved threads
4 files
+ 13
3
Compare changes
  • Side-by-side
  • Inline
Files
4
+ 6
1
@@ -183,7 +183,9 @@ class GMMStats:
@@ -183,7 +183,9 @@ class GMMStats:
Second order statistic
Second order statistic
"""
"""
def __init__(self, n_gaussians: int, n_features: int) -> None:
def __init__(self, n_gaussians: int, n_features: int, **kwargs) -> None:
 
super().__init__(**kwargs)
 
self.n_gaussians = n_gaussians
self.n_gaussians = n_gaussians
self.n_features = n_features
self.n_features = n_features
self.log_likelihood = 0
self.log_likelihood = 0
@@ -420,6 +422,7 @@ class GMMMachine(BaseEstimator):
@@ -420,6 +422,7 @@ class GMMMachine(BaseEstimator):
mean_var_update_threshold: float = EPSILON,
mean_var_update_threshold: float = EPSILON,
map_alpha: float = 0.5,
map_alpha: float = 0.5,
map_relevance_factor: Union[None, float] = 4,
map_relevance_factor: Union[None, float] = 4,
 
**kwargs,
):
):
"""
"""
Parameters
Parameters
@@ -459,6 +462,8 @@ class GMMMachine(BaseEstimator):
@@ -459,6 +462,8 @@ class GMMMachine(BaseEstimator):
`trainer == "map"`)
`trainer == "map"`)
"""
"""
 
super().__init__(**kwargs)
 
self.n_gaussians = n_gaussians
self.n_gaussians = n_gaussians
self.trainer = trainer if trainer in ["ml", "map"] else "ml"
self.trainer = trainer if trainer in ["ml", "map"] else "ml"
self.m_step_func = (
self.m_step_func = (
Loading