
Mixin classes for sklearn estimators should not have an __init__ method

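I thought we could get away with this, but apparently we cannot. BaseEstimator discovers an estimator's parameters by inspecting the signature of its __init__ method, specifically the one resolved through the MRO. An extra __init__ in a mixin therefore hides the wrapped estimator's parameters, and get_params/set_params break.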

Here is an example:

In [2]: from sklearn.svm import SVC
   ...: from bob.pipelines.mixins import CheckpointMixin, SampleMixin
   ...: class CheckpointSampleSVC(CheckpointMixin, SampleMixin, SVC):
   ...:     pass
   ...:

In [8]: original_estimator = SVC()

In [9]: original_estimator
Out[9]: 
SVC(C=1.0, break_ties=False, cache_size=200, class_weight=None, coef0=0.0,
    decision_function_shape='ovr', degree=3, gamma='scale', kernel='rbf',
    max_iter=-1, probability=False, random_state=None, shrinking=True,
    tol=0.001, verbose=False)

In [10]: original_estimator.set_params(C=2)
Out[10]: 
SVC(C=2, break_ties=False, cache_size=200, class_weight=None, coef0=0.0,
    decision_function_shape='ovr', degree=3, gamma='scale', kernel='rbf',
    max_iter=-1, probability=False, random_state=None, shrinking=True,
    tol=0.001, verbose=False)

In [11]: checkpointing_sample_estimator = CheckpointSampleSVC()

In [12]: checkpointing_sample_estimator
Out[12]: 
CheckpointSampleSVC(extension='.h5', features_dir=None,
                    load_func=<function load at 0x7f1ce85e5290>,
                    model_path=None,
                    save_func=<function save at 0x7f1ce85e53b0>)

In [13]: checkpointing_sample_estimator.set_params(C=2)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-13-bbed69696a06> in <module>
----> 1 checkpointing_sample_estimator.set_params(C=2)

conda/envs/dask/lib/python3.7/site-packages/sklearn/base.py in set_params(self, **params)
    234                                  'Check the list of available parameters '
    235                                  'with `estimator.get_params().keys()`.' %
--> 236                                  (key, self))
    237
    238             if delim:

ValueError: Invalid parameter C for estimator CheckpointSampleSVC(extension='.h5', features_dir=None,
                    load_func=<function load at 0x7f1ce85e5290>,
                    model_path=None,
                    save_func=<function save at 0x7f1ce85e53b0>). Check the list of available parameters with `estimator.get_params().keys()`.
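The same failure can be reproduced without bob.pipelines or scikit-learn. The sketch below assumes a simplified MiniBaseEstimator that models only how BaseEstimator discovers parameters from the resolved __init__ signature. ToySVC, HypotheticalCheckpointMixin and features_dir are hypothetical names for illustration, not the real bob.pipelines or scikit-learn API.

```python
import inspect


class MiniBaseEstimator:
    """Simplified stand-in for sklearn.base.BaseEstimator (assumption: only
    parameter discovery is modeled, not deep params or the pretty repr)."""

    @classmethod
    def _get_param_names(cls):
        # Read parameter names from the __init__ that the MRO resolves for
        # this class; **kwargs is ignored, as in scikit-learn.
        init = cls.__init__
        if init is object.__init__:
            return []
        params = inspect.signature(init).parameters.values()
        return sorted(
            p.name
            for p in params
            if p.name != "self" and p.kind != inspect.Parameter.VAR_KEYWORD
        )

    def get_params(self):
        return {name: getattr(self, name) for name in self._get_param_names()}

    def set_params(self, **params):
        valid = self.get_params()
        for key, value in params.items():
            if key not in valid:
                raise ValueError(f"Invalid parameter {key} for estimator {self!r}")
            setattr(self, key, value)
        return self


class ToySVC(MiniBaseEstimator):
    # Hypothetical estimator standing in for SVC.
    def __init__(self, C=1.0, kernel="rbf"):
        self.C = C
        self.kernel = kernel


class HypotheticalCheckpointMixin:
    # Hypothetical mixin that defines its own __init__ (the problematic pattern).
    def __init__(self, features_dir=None, **kwargs):
        self.features_dir = features_dir
        super().__init__(**kwargs)


class CheckpointToySVC(HypotheticalCheckpointMixin, ToySVC):
    pass


print(ToySVC().get_params())        # {'C': 1.0, 'kernel': 'rbf'}
print(ToySVC().set_params(C=2).C)   # 2

est = CheckpointToySVC()
# est.C is set on the instance, but get_params only sees the mixin's signature.
print(est.get_params())             # {'features_dir': None}
try:
    est.set_params(C=2)
except ValueError as err:
    print("ValueError:", err)
```

Because the MRO resolves CheckpointToySVC.__init__ to the mixin's __init__, only features_dir is reported as a parameter, which matches the repr and the ValueError in the transcript.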

set_params is important because meta-estimators such as GridSearchCV rely on it to try each candidate parameter setting: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn.model_selection.GridSearchCV
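GridSearchCV clones the base estimator with sklearn.base.clone, which rebuilds it from get_params(), and then calls set_params for each candidate. The dependency-free sketch below models those two steps with hypothetical mini_clone and mini_grid helpers. It contrasts a mixin that defines __init__ with a mixin that does not, where the concrete class declares every parameter explicitly. All class and helper names are hypothetical. This is one possible layout under those assumptions, not the bob.pipelines implementation.

```python
import inspect


class MiniBaseEstimator:
    # Compact stand-in for sklearn.base.BaseEstimator (assumption): the
    # parameters are exactly those named in the resolved __init__ signature.
    def get_params(self):
        sig = inspect.signature(type(self).__init__)
        names = [
            p.name
            for p in sig.parameters.values()
            if p.name != "self" and p.kind != inspect.Parameter.VAR_KEYWORD
        ]
        return {name: getattr(self, name) for name in names}

    def set_params(self, **params):
        valid = self.get_params()
        for key, value in params.items():
            if key not in valid:
                raise ValueError(f"Invalid parameter {key} for {type(self).__name__}")
            setattr(self, key, value)
        return self


def mini_clone(est):
    # Modeled on sklearn.base.clone: rebuild an unfitted copy from get_params().
    return type(est)(**est.get_params())


def mini_grid(est, name, values):
    # Modeled on what GridSearchCV does per candidate: clone, then set_params.
    return [getattr(mini_clone(est).set_params(**{name: v}), name) for v in values]


class ToySVC(MiniBaseEstimator):
    def __init__(self, C=1.0, kernel="rbf"):
        self.C = C
        self.kernel = kernel


class BrokenMixin:
    # Hypothetical mixin with its own __init__: hides ToySVC's parameters.
    def __init__(self, features_dir=None, **kwargs):
        self.features_dir = features_dir
        super().__init__(**kwargs)


class BrokenCheckpointSVC(BrokenMixin, ToySVC):
    pass


class NoInitCheckpointMixin:
    # Hypothetical mixin without __init__: it only adds behavior and reads
    # attributes that the concrete estimator is expected to set.
    def checkpoint_path(self, key):
        return f"{self.features_dir}/{key}.h5"


class CheckpointSVC(NoInitCheckpointMixin, ToySVC):
    # The concrete class declares every parameter explicitly.
    def __init__(self, C=1.0, kernel="rbf", features_dir=None):
        super().__init__(C=C, kernel=kernel)
        self.features_dir = features_dir


broken = BrokenCheckpointSVC(C=5.0)
print(mini_clone(broken).C)        # 1.0: the clone silently lost C=5.0
try:
    mini_grid(broken, "C", [0.1, 10.0])
except ValueError as err:
    print("ValueError:", err)      # Invalid parameter C for BrokenCheckpointSVC

fixed = CheckpointSVC(C=5.0, features_dir="/tmp/features")
print(mini_clone(fixed).C)                 # 5.0
print(mini_grid(fixed, "C", [0.1, 10.0]))  # [0.1, 10.0]
print(fixed.checkpoint_path("sample1"))    # /tmp/features/sample1.h5
```

The design choice in this sketch is that the mixin contributes behavior only, while the full parameter list lives in a single explicit __init__ signature, which is what the signature-based introspection expects. Note that with the broken mixin, cloning fails silently (C is reset) before set_params ever raises.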

Edited by Amir MOHAMMADI