Checking if the super class has the method fit to avoid unnecessary stacking of samples for the base class

parent 9fbcf48e
Pipeline #38302 passed with stage
in 4 minutes and 11 seconds
......@@ -12,7 +12,7 @@ from sklearn.pipeline import Pipeline
from dask import delayed
import dask.bag
def dask_it(o, fit_tag=None, transform_tag=None):
def dask_it(o, fit_tag=None, transform_tag=None, npartitions=None):
"""
Mix up any :py:class:`sklearn.pipeline.Pipeline` or :py:class:`sklearn.base.BaseEstimator` with
:py:class:`DaskEstimatorMixin`
......@@ -76,7 +76,7 @@ def dask_it(o, fit_tag=None, transform_tag=None):
if isinstance(o, Pipeline):
# Adding a dask bag at the head of the pipeline
o.steps.insert(0, ('0', DaskBagMixin()))
o.steps.insert(0, ('0', DaskBagMixin(npartitions=npartitions)))
# Patching dask_resources
dasked = mix_me_up(DaskEstimatorMixin, o)
......@@ -184,9 +184,6 @@ class SampleMixin:
Also implement ``predict``, ``predict_proba``, and ``score``. See:
https://scikit-learn.org/stable/developers/develop.html#apis-of-scikit-learn-objects
.. todo::
Allow handling the targets given to the ``fit`` method.
"""
def transform(self, samples):
......@@ -206,7 +203,13 @@ class SampleMixin:
def fit(self, samples, y=None):
return super().fit([s.data for s in samples], y=y)
# IF THE SUPER METHOD IS NOT FITTABLE,
# THERE'S NO REASON TO STACK THOSE SAMPLES
if( hasattr(super(), "fit")):
return super().fit([s.data for s in samples], y=y)
return self
class CheckpointMixin:
......@@ -251,13 +254,21 @@ class CheckpointMixin:
def fit(self, samples, y=None):
# IF THE SUPER METHOD IS NOT FITTABLE,
# THERE'S NO REASON TO STACK THOSE SAMPLES
if( not hasattr(super(), "fit") ):
return self
if self.model_path is not None and os.path.isfile(self.model_path):
return self.load_model()
super().fit(samples, y=y)
return self.save_model()
def fit_transform(self, samples, y=None):
return self.fit(samples, y=y).transform(samples)
def make_path(self, sample):
......@@ -403,7 +414,7 @@ class DaskEstimatorMixin:
return self
def transform(self, X):
def _transf(X_line, dask_state):
def _transf(X_line, dask_state):
return super(DaskEstimatorMixin, dask_state).transform(X_line)
map_partitions = X.map_partitions(_transf, self._dask_state)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment