Commit bd7427e3 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Better names in dask graph for FunctionTransformer

parent 6788926e
Pipeline #45298 passed with stage
in 3 minutes and 55 seconds
......@@ -13,6 +13,7 @@ from sklearn.base import BaseEstimator
from sklearn.base import MetaEstimatorMixin
from sklearn.base import TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
from .sample import DelayedSample
from .sample import SampleBatch
......@@ -31,7 +32,15 @@ def _frmt(estimator, limit=30):
while hasattr(estimator, "estimator"):
name += f"{_n(estimator)}|"
estimator = estimator.estimator
if (
isinstance(estimator, FunctionTransformer)
and type(estimator) is FunctionTransformer
):
name += str(estimator.func.__name__)
else:
name += str(estimator)
name = f"{name:.{limit}}"
return name
......@@ -128,7 +137,8 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
if isinstance(samples[0], SampleSet):
return [
SampleSet(
self._samples_transform(sset.samples, method_name), parent=sset,
self._samples_transform(sset.samples, method_name),
parent=sset,
)
for sset in samples
]
......@@ -366,7 +376,11 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
"""
def __init__(
self, estimator, fit_tag=None, transform_tag=None, **kwargs,
self,
estimator,
fit_tag=None,
transform_tag=None,
**kwargs,
):
super().__init__(**kwargs)
self.estimator = estimator
......@@ -418,7 +432,10 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
# change the name to have a better name in dask graphs
_fit.__name__ = f"{_frmt(self)}.fit"
self._dask_state = delayed(_fit)(X, y,)
self._dask_state = delayed(_fit)(
X,
y,
)
if self.fit_tag is not None:
self.resource_tags[self._dask_state] = self.fit_tag
......@@ -511,7 +528,9 @@ def wrap(bases, estimator=None, **kwargs):
# when checkpointing a pipeline, checkpoint each transformer in its own folder
new_kwargs = dict(kwargs)
features_dir, model_path = kwargs.get("features_dir"), kwargs.get("model_path")
features_dir, model_path = kwargs.get("features_dir"), kwargs.get(
"model_path"
)
if features_dir is not None:
new_kwargs["features_dir"] = os.path.join(features_dir, name)
if model_path is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment