diff --git a/bob/learn/em/gmm.py b/bob/learn/em/gmm.py index 1a439e142cc16eed28993f10c5db16c0a2ff3084..d9ae87b4111e9bac008a95ee8d8869d003242fc8 100644 --- a/bob/learn/em/gmm.py +++ b/bob/learn/em/gmm.py @@ -183,7 +183,9 @@ class GMMStats: Second order statistic """ - def __init__(self, n_gaussians: int, n_features: int) -> None: + def __init__(self, n_gaussians: int, n_features: int, **kwargs) -> None: + super().__init__(**kwargs) + self.n_gaussians = n_gaussians self.n_features = n_features self.log_likelihood = 0 diff --git a/bob/learn/em/kmeans.py b/bob/learn/em/kmeans.py index 61af76ec9a54dce65c43d7439ecbbb8a78b61031..6cd29e4cf33de36a7792702cf383a5fb2c62666c 100644 --- a/bob/learn/em/kmeans.py +++ b/bob/learn/em/kmeans.py @@ -202,6 +202,7 @@ class KMeansMachine(BaseEstimator): random_state: Union[int, np.random.RandomState] = 0, init_max_iter: Union[int, None] = 5, oversampling_factor: float = 2, + **kwargs, ) -> None: """ Parameters @@ -219,6 +220,8 @@ class KMeansMachine(BaseEstimator): The maximum number of iterations for the initialization part. """ + super().__init__(**kwargs) + if n_clusters < 1: raise ValueError("The Number of cluster should be greater thant 0.") self.n_clusters = n_clusters diff --git a/bob/learn/em/wccn.py b/bob/learn/em/wccn.py index 47d5b80463ebb74e0c1e7f8b8dccc1f52ed31529..32a4e80788c135a5b5a8c5c07d3c5fc1abc59eaf 100644 --- a/bob/learn/em/wccn.py +++ b/bob/learn/em/wccn.py @@ -34,7 +34,8 @@ class WCCN(TransformerMixin, BaseEstimator): """ - def __init__(self, pinv=False): + def __init__(self, pinv=False, **kwargs): + super().__init__(**kwargs) self.pinv = pinv def fit(self, X, y): diff --git a/bob/learn/em/whitening.py b/bob/learn/em/whitening.py index bf81a36203580120e2d3f90664bed38401809d58..4b78da8dbd3c3a2fa75bade6a31dd4d11c535938 100644 --- a/bob/learn/em/whitening.py +++ b/bob/learn/em/whitening.py @@ -41,7 +41,8 @@ class Whitening(TransformerMixin, BaseEstimator): """ - def __init__(self, pinv: bool = False): + def __init__(self, pinv: bool = False, **kwargs): + super().__init__(**kwargs) self.pinv = pinv def fit(self, X, y=None):