bob / bob.pipelines: Issue #11
Issue created Apr 22, 2020 by Amir MOHAMMADI (@amohammadi, Owner)

Mixin classes for sklearn estimators should not have an __init__ method

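I thought we could get away with this, but apparently we cannot. BaseEstimator discovers an estimator's parameters by inspecting the signature of the class's __init__, and get_params() and set_params() only accept the names found there. Because attribute lookup follows the MRO, a mixin listed before SVC that defines its own __init__ shadows SVC.__init__. The estimator then "forgets" SVC's parameters (C, kernel, ...), so providing an extra __init__ method in a mixin breaks it.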

Here is an example:

In [2]: from sklearn.svm import SVC
   ...: from bob.pipelines.mixins import CheckpointMixin, SampleMixin
   ...: class CheckpointSampleSVC(CheckpointMixin, SampleMixin, SVC):
   ...:     pass
   ...:

In [8]: original_estimator = SVC()

In [9]: original_estimator
Out[9]: 
SVC(C=1.0, break_ties=False, cache_size=200, class_weight=None, coef0=0.0,
    decision_function_shape='ovr', degree=3, gamma='scale', kernel='rbf',
    max_iter=-1, probability=False, random_state=None, shrinking=True,
    tol=0.001, verbose=False)

In [10]: original_estimator.set_params(C=2)
Out[10]: 
SVC(C=2, break_ties=False, cache_size=200, class_weight=None, coef0=0.0,
    decision_function_shape='ovr', degree=3, gamma='scale', kernel='rbf',
    max_iter=-1, probability=False, random_state=None, shrinking=True,
    tol=0.001, verbose=False)

In [11]: checkpointing_sample_estimator = CheckpointSampleSVC()

In [12]: checkpointing_sample_estimator
Out[12]: 
CheckpointSampleSVC(extension='.h5', features_dir=None,
                    load_func=<function load at 0x7f1ce85e5290>,
                    model_path=None,
                    save_func=<function save at 0x7f1ce85e53b0>)

In [13]: checkpointing_sample_estimator.set_params(C=2)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-13-bbed69696a06> in <module>
----> 1 checkpointing_sample_estimator.set_params(C=2)

conda/envs/dask/lib/python3.7/site-packages/sklearn/base.py in set_params(self, **params)
    234                                  'Check the list of available parameters '
    235                                  'with `estimator.get_params().keys()`.' %
--> 236                                  (key, self))
    237
    238             if delim:

ValueError: Invalid parameter C for estimator CheckpointSampleSVC(extension='.h5', features_dir=None,
                    load_func=<function load at 0x7f1ce85e5290>,
                    model_path=None,
                    save_func=<function save at 0x7f1ce85e53b0>). Check the list of available parameters with `estimator.get_params().keys()`.
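To see why set_params(C=2) fails, here is a minimal, sklearn-free sketch of the parameter-discovery step. ParamsBase is a simplified stand-in for BaseEstimator: it mimics only how parameter names are read from the __init__ signature (the real class also handles deep=True, nested "step__param" keys, and rejects *args). ToySVC, ToyCheckpointMixin and ToyCheckpointSVC are hypothetical classes for illustration, not bob.pipelines code.

```python
import inspect


class ParamsBase:
    """Simplified stand-in for sklearn.base.BaseEstimator (assumption: only the
    parameter-discovery logic is mimicked, not the full class)."""

    @classmethod
    def _get_param_names(cls):
        # Read the signature of cls.__init__. Attribute lookup follows the MRO,
        # so the FIRST __init__ found in the class hierarchy wins.
        init = cls.__init__
        if init is object.__init__:
            return []
        params = inspect.signature(init).parameters.values()
        return sorted(
            p.name
            for p in params
            if p.name != "self" and p.kind not in (p.VAR_KEYWORD, p.VAR_POSITIONAL)
        )

    def get_params(self):
        return {name: getattr(self, name) for name in self._get_param_names()}

    def set_params(self, **params):
        valid = self._get_param_names()
        for key, value in params.items():
            if key not in valid:
                raise ValueError(
                    f"Invalid parameter {key} for estimator {type(self).__name__}"
                )
            setattr(self, key, value)
        return self


class ToySVC(ParamsBase):
    def __init__(self, C=1.0, kernel="rbf"):
        self.C = C
        self.kernel = kernel


class ToyCheckpointMixin:
    # Hypothetical mixin WITH its own __init__: the pattern this issue warns about.
    def __init__(self, features_dir=None, extension=".h5", **kwargs):
        super().__init__(**kwargs)
        self.features_dir = features_dir
        self.extension = extension


class ToyCheckpointSVC(ToyCheckpointMixin, ToySVC):
    pass


print(sorted(ToySVC().get_params()))  # ['C', 'kernel']

est = ToyCheckpointSVC()
# The mixin's __init__ shadows ToySVC.__init__, so C and kernel are not parameters.
print(sorted(est.get_params()))  # ['extension', 'features_dir']
try:
    est.set_params(C=2)
except ValueError as err:
    print(err)  # Invalid parameter C for estimator ToyCheckpointSVC
```

Note that est.C is still set as an attribute (the mixin forwards to ToySVC.__init__); it is simply invisible to get_params()/set_params(), which matches the repr and the error in the session.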

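This matters because meta-estimators such as GridSearchCV (https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn.model_selection.GridSearchCV) call set_params to try each candidate parameter setting, and sklearn.base.clone rebuilds an estimator from get_params(). An estimator whose parameters are hidden by a mixin's __init__ cannot be tuned or cloned correctly.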
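One possible way to follow the rule in the title, sketched with the same kind of simplified stand-in (ParamsBase, clone, ToySVC, ToyCheckpointMixin and ToyCheckpointSVC are all hypothetical, not bob.pipelines' actual API): the mixin defines no __init__ and only adds behaviour, while the concrete class declares every parameter explicitly. The clone helper mirrors the idea of sklearn.base.clone (rebuild from get_params()), and the loop roughly imitates how a grid search clones the base estimator and calls set_params for each candidate.

```python
import inspect
from itertools import product


class ParamsBase:
    """Simplified stand-in for BaseEstimator's parameter logic (assumption)."""

    @classmethod
    def _get_param_names(cls):
        init = cls.__init__
        if init is object.__init__:
            return []
        params = inspect.signature(init).parameters.values()
        return sorted(
            p.name
            for p in params
            if p.name != "self" and p.kind not in (p.VAR_KEYWORD, p.VAR_POSITIONAL)
        )

    def get_params(self):
        return {name: getattr(self, name) for name in self._get_param_names()}

    def set_params(self, **params):
        valid = self._get_param_names()
        for key, value in params.items():
            if key not in valid:
                raise ValueError(f"Invalid parameter {key} for {type(self).__name__}")
            setattr(self, key, value)
        return self


def clone(estimator):
    # Hypothetical helper: build an unfitted copy from the declared parameters.
    return type(estimator)(**estimator.get_params())


class ToySVC(ParamsBase):
    def __init__(self, C=1.0, kernel="rbf"):
        self.C = C
        self.kernel = kernel


class ToyCheckpointMixin:
    # No __init__: the mixin only reads attributes the concrete class sets.
    def checkpoint_path(self, key):
        return f"{self.features_dir}/{key}{self.extension}"


class ToyCheckpointSVC(ToyCheckpointMixin, ToySVC):
    # Every parameter (SVC's and the mixin's) appears in this signature.
    def __init__(self, C=1.0, kernel="rbf", features_dir=None, extension=".h5"):
        super().__init__(C=C, kernel=kernel)  # resolves to ToySVC.__init__
        self.features_dir = features_dir
        self.extension = extension


# Roughly what a grid search does per candidate: clone, then set_params.
base = ToyCheckpointSVC(features_dir="/tmp/feats")  # hypothetical directory
grid = {"C": [1, 2], "kernel": ["rbf", "linear"]}
candidates = []
for values in product(*grid.values()):
    candidate = clone(base).set_params(**dict(zip(grid, values)))
    candidates.append(candidate.get_params())

print(ToyCheckpointSVC._get_param_names())
# ['C', 'extension', 'features_dir', 'kernel']
```

The trade-off of this design is that each concrete class must repeat the mixin's parameters in its own __init__. An alternative worth weighing is composition: a wrapper estimator that takes the inner estimator as a parameter, the way sklearn's own meta-estimators do.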

Edited Apr 22, 2020 by Amir MOHAMMADI