Skip to content
Snippets Groups Projects
Commit 96076c98 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Checking if the super class has the method fit to avoid unecessary stack from the base class

parent 9fbcf48e
No related branches found
No related tags found
1 merge request!8Checking if the super class has the method fit to avoid unecessary stack from the base class
Pipeline #38302 passed
...@@ -12,7 +12,7 @@ from sklearn.pipeline import Pipeline ...@@ -12,7 +12,7 @@ from sklearn.pipeline import Pipeline
from dask import delayed from dask import delayed
import dask.bag import dask.bag
def dask_it(o, fit_tag=None, transform_tag=None): def dask_it(o, fit_tag=None, transform_tag=None, npartitions=None):
""" """
Mix up any :py:class:`sklearn.pipeline.Pipeline` or :py:class:`sklearn.estimator.Base` with Mix up any :py:class:`sklearn.pipeline.Pipeline` or :py:class:`sklearn.estimator.Base` with
:py:class`DaskEstimatorMixin` :py:class`DaskEstimatorMixin`
...@@ -76,7 +76,7 @@ def dask_it(o, fit_tag=None, transform_tag=None): ...@@ -76,7 +76,7 @@ def dask_it(o, fit_tag=None, transform_tag=None):
if isinstance(o, Pipeline): if isinstance(o, Pipeline):
#Adding a daskbag in the tail of the pipeline #Adding a daskbag in the tail of the pipeline
o.steps.insert(0, ('0', DaskBagMixin())) o.steps.insert(0, ('0', DaskBagMixin(npartitions=npartitions)))
# Patching dask_resources # Patching dask_resources
dasked = mix_me_up(DaskEstimatorMixin, o) dasked = mix_me_up(DaskEstimatorMixin, o)
...@@ -184,9 +184,6 @@ class SampleMixin: ...@@ -184,9 +184,6 @@ class SampleMixin:
Also implement ``predict``, ``predict_proba``, and ``score``. See: Also implement ``predict``, ``predict_proba``, and ``score``. See:
https://scikit-learn.org/stable/developers/develop.html#apis-of-scikit-learn-objects https://scikit-learn.org/stable/developers/develop.html#apis-of-scikit-learn-objects
.. todo::
Allow handling the targets given to the ``fit`` method.
""" """
def transform(self, samples): def transform(self, samples):
...@@ -206,7 +203,13 @@ class SampleMixin: ...@@ -206,7 +203,13 @@ class SampleMixin:
def fit(self, samples, y=None): def fit(self, samples, y=None):
return super().fit([s.data for s in samples], y=y)
# IF THE SUPER METHOD IS NOT FITTABLE,
# THERE'S NO REASON TO STACK THOSE SAMPLES
if( hasattr(super(), "fit")):
return super().fit([s.data for s in samples], y=y)
return self
class CheckpointMixin: class CheckpointMixin:
...@@ -251,13 +254,21 @@ class CheckpointMixin: ...@@ -251,13 +254,21 @@ class CheckpointMixin:
def fit(self, samples, y=None): def fit(self, samples, y=None):
# IF THE SUPER METHOD IS NOT FITTABLE,
# THERE'S NO REASON TO STACK THOSE SAMPLES
if( not hasattr(super(), "fit") ):
return self
if self.model_path is not None and os.path.isfile(self.model_path): if self.model_path is not None and os.path.isfile(self.model_path):
return self.load_model() return self.load_model()
super().fit(samples, y=y) super().fit(samples, y=y)
return self.save_model() return self.save_model()
def fit_transform(self, samples, y=None): def fit_transform(self, samples, y=None):
return self.fit(samples, y=y).transform(samples) return self.fit(samples, y=y).transform(samples)
def make_path(self, sample): def make_path(self, sample):
...@@ -403,7 +414,7 @@ class DaskEstimatorMixin: ...@@ -403,7 +414,7 @@ class DaskEstimatorMixin:
return self return self
def transform(self, X): def transform(self, X):
def _transf(X_line, dask_state): def _transf(X_line, dask_state):
return super(DaskEstimatorMixin, dask_state).transform(X_line) return super(DaskEstimatorMixin, dask_state).transform(X_line)
map_partitions = X.map_partitions(_transf, self._dask_state) map_partitions = X.map_partitions(_transf, self._dask_state)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment