Mixin classes for sklearn estimators should not have an __init__ method
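I was thinking we could get away with this, but apparently we cannot. scikit-learn's `BaseEstimator` discovers an estimator's parameters by inspecting the signature of `type(self).__init__`. Python resolves that `__init__` through the MRO, so when a mixin that comes first in the bases defines its own `__init__`, the mixin's signature is the only one inspected. The parameters of the actual estimator (e.g. `C` for `SVC`) disappear from `get_params`/`set_params`, which breaks the estimator.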
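A minimal, standard-library-only sketch that reproduces the mechanism without sklearn or bob.pipelines. `MiniBaseEstimator`, `ToySVC`, `CheckpointMixinWithInit` and `CheckpointToySVC` are hypothetical names. `MiniBaseEstimator` only imitates, in simplified form, how `BaseEstimator` collects parameter names from the constructor signature; it is not the real implementation.

```python
import inspect


class MiniBaseEstimator:
    """Hypothetical, simplified stand-in for sklearn's BaseEstimator."""

    @classmethod
    def _get_param_names(cls):
        # Read parameter names from the signature of cls.__init__.
        # Python resolves cls.__init__ through the MRO, so the FIRST class
        # that defines __init__ (here: the mixin) decides which parameters exist.
        init = cls.__init__
        if init is object.__init__:
            return []
        params = inspect.signature(init).parameters.values()
        return sorted(p.name for p in params if p.name != "self" and p.kind != p.VAR_KEYWORD)

    def get_params(self):
        return {name: getattr(self, name) for name in self._get_param_names()}

    def set_params(self, **params):
        valid = self._get_param_names()
        for key, value in params.items():
            if key not in valid:
                raise ValueError(f"Invalid parameter {key} for estimator {type(self).__name__}")
            setattr(self, key, value)
        return self


class ToySVC(MiniBaseEstimator):
    def __init__(self, C=1.0, kernel="rbf"):
        self.C = C
        self.kernel = kernel


class CheckpointMixinWithInit:
    # Hypothetical mixin that defines its own __init__ (the problematic pattern).
    def __init__(self, features_dir=None, extension=".h5", **kwargs):
        super().__init__(**kwargs)
        self.features_dir = features_dir
        self.extension = extension


class CheckpointToySVC(CheckpointMixinWithInit, ToySVC):
    pass


print(ToySVC().set_params(C=2).get_params())  # {'C': 2, 'kernel': 'rbf'}
print(CheckpointToySVC._get_param_names())  # ['extension', 'features_dir']
try:
    CheckpointToySVC().set_params(C=2)
except ValueError as err:
    print(err)  # Invalid parameter C for estimator CheckpointToySVC
```

Note that the `CheckpointToySVC` instance still has a `C` attribute (set by `ToySVC.__init__` through `**kwargs`), but `get_params` cannot see it. This matches the repr in the session, where `CheckpointSampleSVC` lists none of the `SVC` parameters.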
Here is an example:
```
In [2]: from sklearn.svm import SVC
   ...: from bob.pipelines.mixins import CheckpointMixin, SampleMixin
   ...: class CheckpointSampleSVC(CheckpointMixin, SampleMixin, SVC):
   ...:     pass
   ...:

In [8]: original_estimator = SVC()

In [9]: original_estimator
Out[9]:
SVC(C=1.0, break_ties=False, cache_size=200, class_weight=None, coef0=0.0,
    decision_function_shape='ovr', degree=3, gamma='scale', kernel='rbf',
    max_iter=-1, probability=False, random_state=None, shrinking=True,
    tol=0.001, verbose=False)

In [10]: original_estimator.set_params(C=2)
Out[10]:
SVC(C=2, break_ties=False, cache_size=200, class_weight=None, coef0=0.0,
    decision_function_shape='ovr', degree=3, gamma='scale', kernel='rbf',
    max_iter=-1, probability=False, random_state=None, shrinking=True,
    tol=0.001, verbose=False)

In [11]: checkpointing_sample_estimator = CheckpointSampleSVC()

In [12]: checkpointing_sample_estimator
Out[12]:
CheckpointSampleSVC(extension='.h5', features_dir=None,
                    load_func=<function load at 0x7f1ce85e5290>,
                    model_path=None,
                    save_func=<function save at 0x7f1ce85e53b0>)

In [13]: checkpointing_sample_estimator.set_params(C=2)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-13-bbed69696a06> in <module>
----> 1 checkpointing_sample_estimator.set_params(C=2)

conda/envs/dask/lib/python3.7/site-packages/sklearn/base.py in set_params(self, **params)
    234                                  'Check the list of available parameters '
    235                                  'with `estimator.get_params().keys()`.' %
--> 236                                  (key, self))
    237
    238             if delim:

ValueError: Invalid parameter C for estimator CheckpointSampleSVC(extension='.h5', features_dir=None,
                    load_func=<function load at 0x7f1ce85e5290>,
                    model_path=None,
                    save_func=<function save at 0x7f1ce85e53b0>). Check the list of available parameters with `estimator.get_params().keys()`.
```
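`set_params` is important because it is used by meta-estimators such as GridSearchCV (https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn.model_selection.GridSearchCV), which clone the estimator and call `set_params` with each candidate parameter combination.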
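One possible way to follow the rule in the title, sketched with the standard library only: the mixin adds behaviour but no `__init__`, and the concrete class declares every parameter explicitly in its own `__init__`. All names here are hypothetical. `MiniBaseEstimator` is a simplified stand-in that, like sklearn's `BaseEstimator`, reads parameter names from the `__init__` signature. `mini_grid_search` is a heavily simplified imitation of the clone-then-`set_params` pattern a grid search relies on, not sklearn's implementation.

```python
import inspect
from itertools import product


class MiniBaseEstimator:
    """Hypothetical, simplified stand-in for sklearn's BaseEstimator."""

    @classmethod
    def _get_param_names(cls):
        # Parameter names come from the signature of cls.__init__ (resolved via the MRO).
        init = cls.__init__
        if init is object.__init__:
            return []
        params = inspect.signature(init).parameters.values()
        return sorted(p.name for p in params if p.name != "self" and p.kind != p.VAR_KEYWORD)

    def get_params(self):
        return {name: getattr(self, name) for name in self._get_param_names()}

    def set_params(self, **params):
        valid = self._get_param_names()
        for key, value in params.items():
            if key not in valid:
                raise ValueError(f"Invalid parameter {key} for estimator {type(self).__name__}")
            setattr(self, key, value)
        return self


class ToySVC(MiniBaseEstimator):
    def __init__(self, C=1.0, kernel="rbf"):
        self.C = C
        self.kernel = kernel


class CheckpointMixin:
    """Hypothetical mixin: adds methods only, no __init__."""

    def checkpoint_path(self, key):
        # Relies on attributes that the concrete class's __init__ sets.
        if self.features_dir is None:
            return None
        return f"{self.features_dir}/{key}{self.extension}"


class CheckpointToySVC(CheckpointMixin, ToySVC):
    # The concrete class declares every parameter, so the introspected
    # signature covers both the SVC and the checkpoint parameters.
    def __init__(self, C=1.0, kernel="rbf", features_dir=None, extension=".h5"):
        super().__init__(C=C, kernel=kernel)  # CheckpointMixin has no __init__, so this reaches ToySVC
        self.features_dir = features_dir
        self.extension = extension


def mini_grid_search(estimator, param_grid):
    """Hypothetical sketch of the get_params/set_params calls a grid search makes."""
    names = sorted(param_grid)
    candidates = []
    for values in product(*(param_grid[n] for n in names)):
        # Build a fresh copy from get_params(), roughly like sklearn.base.clone.
        candidate = type(estimator)(**estimator.get_params())
        candidate.set_params(**dict(zip(names, values)))
        candidates.append(candidate.get_params())
    return candidates


est = CheckpointToySVC(features_dir="/tmp/feats")
results = mini_grid_search(est, {"C": [1, 10], "kernel": ["linear", "rbf"]})
print(len(results))  # 4
print(results[0])  # {'C': 1, 'extension': '.h5', 'features_dir': '/tmp/feats', 'kernel': 'linear'}
```

The cost of this pattern is that every concrete class has to repeat the wrapped estimator's parameters in its own `__init__`; the benefit is that `get_params`/`set_params` see the full parameter set.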
Edited by Amir MOHAMMADI