From 71ad94f7461624979be0a719ba859b4fcf46b1cf Mon Sep 17 00:00:00 2001
From: Yannick DAYER
Date: Tue, 3 May 2022 19:57:26 +0200
Subject: [PATCH 1/3] Add super().__init__() to GMMMachine

Allows inheriting classes to call super().__init__() on multiple parents
---
 bob/learn/em/gmm.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index 988ef3a..06f3850 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -420,6 +420,7 @@ class GMMMachine(BaseEstimator):
         mean_var_update_threshold: float = EPSILON,
         map_alpha: float = 0.5,
         map_relevance_factor: Union[None, float] = 4,
+        **kwargs,
     ):
         """
         Parameters
@@ -505,6 +506,7 @@ class GMMMachine(BaseEstimator):
         self.weights = weights
         self.map_alpha = map_alpha
         self.map_relevance_factor = map_relevance_factor
+        super().__init__(**kwargs)
 
     @property
     def weights(self):
-- 
GitLab

From aeee7f1e3370552bab2cf49e713c4d00f48452d2 Mon Sep 17 00:00:00 2001
From: Yannick DAYER
Date: Tue, 3 May 2022 20:06:18 +0200
Subject: [PATCH 2/3] Change the placement of super().__init__()

---
 bob/learn/em/gmm.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index 06f3850..1a439e1 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -460,6 +460,8 @@ class GMMMachine(BaseEstimator):
             `trainer == "map"`)
         """
 
+        super().__init__(**kwargs)
+
         self.n_gaussians = n_gaussians
         self.trainer = trainer if trainer in ["ml", "map"] else "ml"
         self.m_step_func = (
@@ -506,7 +508,6 @@ class GMMMachine(BaseEstimator):
         self.weights = weights
         self.map_alpha = map_alpha
         self.map_relevance_factor = map_relevance_factor
-        super().__init__(**kwargs)
 
     @property
     def weights(self):
-- 
GitLab

From e15df5dda638c13b42b85bdb6970589376a43bb4 Mon Sep 17 00:00:00 2001
From: Yannick DAYER
Date: Tue, 3 May 2022 20:14:15 +0200
Subject: [PATCH 3/3] Add super().__init__() to classes

---
 bob/learn/em/gmm.py       | 4 +++-
 bob/learn/em/kmeans.py    | 3 +++
 bob/learn/em/wccn.py      | 3 ++-
 bob/learn/em/whitening.py | 3 ++-
 4 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py
index 1a439e1..d9ae87b 100644
--- a/bob/learn/em/gmm.py
+++ b/bob/learn/em/gmm.py
@@ -183,7 +183,9 @@ class GMMStats:
         Second order statistic
     """
 
-    def __init__(self, n_gaussians: int, n_features: int) -> None:
+    def __init__(self, n_gaussians: int, n_features: int, **kwargs) -> None:
+        super().__init__(**kwargs)
+
         self.n_gaussians = n_gaussians
         self.n_features = n_features
         self.log_likelihood = 0
diff --git a/bob/learn/em/kmeans.py b/bob/learn/em/kmeans.py
index 61af76e..6cd29e4 100644
--- a/bob/learn/em/kmeans.py
+++ b/bob/learn/em/kmeans.py
@@ -202,6 +202,7 @@ class KMeansMachine(BaseEstimator):
         random_state: Union[int, np.random.RandomState] = 0,
         init_max_iter: Union[int, None] = 5,
         oversampling_factor: float = 2,
+        **kwargs,
     ) -> None:
         """
         Parameters
@@ -219,6 +220,8 @@ class KMeansMachine(BaseEstimator):
             The maximum number of iterations for the initialization part.
         """
 
+        super().__init__(**kwargs)
+
         if n_clusters < 1:
             raise ValueError("The Number of cluster should be greater thant 0.")
         self.n_clusters = n_clusters
diff --git a/bob/learn/em/wccn.py b/bob/learn/em/wccn.py
index 47d5b80..32a4e80 100644
--- a/bob/learn/em/wccn.py
+++ b/bob/learn/em/wccn.py
@@ -34,7 +34,8 @@ class WCCN(TransformerMixin, BaseEstimator):
 
     """
 
-    def __init__(self, pinv=False):
+    def __init__(self, pinv=False, **kwargs):
+        super().__init__(**kwargs)
         self.pinv = pinv
 
     def fit(self, X, y):
diff --git a/bob/learn/em/whitening.py b/bob/learn/em/whitening.py
index bf81a36..4b78da8 100644
--- a/bob/learn/em/whitening.py
+++ b/bob/learn/em/whitening.py
@@ -41,7 +41,8 @@ class Whitening(TransformerMixin, BaseEstimator):
 
     """
 
-    def __init__(self, pinv: bool = False):
+    def __init__(self, pinv: bool = False, **kwargs):
+        super().__init__(**kwargs)
         self.pinv = pinv
 
     def fit(self, X, y=None):
-- 
GitLab