From e15df5dda638c13b42b85bdb6970589376a43bb4 Mon Sep 17 00:00:00 2001 From: Yannick DAYER <yannick.dayer@idiap.ch> Date: Tue, 3 May 2022 20:14:15 +0200 Subject: [PATCH] Add super().__init__() to classes --- bob/learn/em/gmm.py | 4 +++- bob/learn/em/kmeans.py | 3 +++ bob/learn/em/wccn.py | 3 ++- bob/learn/em/whitening.py | 3 ++- 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py index 1a439e1..d9ae87b 100644 --- a/bob/learn/em/gmm.py +++ b/bob/learn/em/gmm.py @@ -183,7 +183,9 @@ class GMMStats: Second order statistic """ - def __init__(self, n_gaussians: int, n_features: int) -> None: + def __init__(self, n_gaussians: int, n_features: int, **kwargs) -> None: + super().__init__(**kwargs) + self.n_gaussians = n_gaussians self.n_features = n_features self.log_likelihood = 0 diff --git a/bob/learn/em/kmeans.py b/bob/learn/em/kmeans.py index 61af76e..6cd29e4 100644 --- a/bob/learn/em/kmeans.py +++ b/bob/learn/em/kmeans.py @@ -202,6 +202,7 @@ class KMeansMachine(BaseEstimator): random_state: Union[int, np.random.RandomState] = 0, init_max_iter: Union[int, None] = 5, oversampling_factor: float = 2, + **kwargs, ) -> None: """ Parameters @@ -219,6 +220,8 @@ class KMeansMachine(BaseEstimator): The maximum number of iterations for the initialization part. """ + super().__init__(**kwargs) + if n_clusters < 1: raise ValueError("The Number of cluster should be greater thant 0.") self.n_clusters = n_clusters diff --git a/bob/learn/em/wccn.py b/bob/learn/em/wccn.py index 47d5b80..32a4e80 100644 --- a/bob/learn/em/wccn.py +++ b/bob/learn/em/wccn.py @@ -34,7 +34,8 @@ class WCCN(TransformerMixin, BaseEstimator): """ - def __init__(self, pinv=False): + def __init__(self, pinv=False, **kwargs): + super().__init__(**kwargs) self.pinv = pinv def fit(self, X, y): diff --git a/bob/learn/em/whitening.py b/bob/learn/em/whitening.py index bf81a36..4b78da8 100644 --- a/bob/learn/em/whitening.py +++ b/bob/learn/em/whitening.py @@ -41,7 +41,8 @@ class Whitening(TransformerMixin, BaseEstimator): """ - def __init__(self, pinv: bool = False): + def __init__(self, pinv: bool = False, **kwargs): + super().__init__(**kwargs) self.pinv = pinv def fit(self, X, y=None): -- GitLab