Commit bd7427e3 authored by Amir MOHAMMADI

Better names in dask graph for FunctionTransformer

parent 6788926e
Pipeline #45298 passed with stage in 3 minutes and 55 seconds
@@ -13,6 +13,7 @@ from sklearn.base import BaseEstimator
 from sklearn.base import MetaEstimatorMixin
 from sklearn.base import TransformerMixin
 from sklearn.pipeline import Pipeline
+from sklearn.preprocessing import FunctionTransformer
 from .sample import DelayedSample
 from .sample import SampleBatch
@@ -31,7 +32,15 @@ def _frmt(estimator, limit=30):
     while hasattr(estimator, "estimator"):
         name += f"{_n(estimator)}|"
         estimator = estimator.estimator
-    name += str(estimator)
+    if (
+        isinstance(estimator, FunctionTransformer)
+        and type(estimator) is FunctionTransformer
+    ):
+        name += str(estimator.func.__name__)
+    else:
+        name += str(estimator)
     name = f"{name:.{limit}}"
     return name
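The new branch gives FunctionTransformer nodes a readable label: str() on a FunctionTransformer produces a long repr that the 30-character limit truncates to mostly boilerplate, whereas the wrapped function's __name__ is short and distinctive. The exact type check (type(estimator) is FunctionTransformer) keeps subclasses on the generic str() path. A minimal sketch of the difference, using a hypothetical to_gray function that is not part of the diff:

from sklearn.preprocessing import FunctionTransformer

def to_gray(images):
    # hypothetical preprocessing function, used only for illustration
    return images.mean(axis=-1)

ft = FunctionTransformer(to_gray)

# before this commit: the label starts from the estimator repr, cut at 30 characters
print(f"{str(ft):.30}")   # roughly "FunctionTransformer(func=<func" (repr varies by sklearn version)
# after this commit: the label is the wrapped function's name
print(ft.func.__name__)   # "to_gray"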
@@ -128,7 +137,8 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
         if isinstance(samples[0], SampleSet):
             return [
                 SampleSet(
-                    self._samples_transform(sset.samples, method_name), parent=sset,
+                    self._samples_transform(sset.samples, method_name),
+                    parent=sset,
                 )
                 for sset in samples
             ]
@@ -366,7 +376,11 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
     """
     def __init__(
-        self, estimator, fit_tag=None, transform_tag=None, **kwargs,
+        self,
+        estimator,
+        fit_tag=None,
+        transform_tag=None,
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.estimator = estimator
@@ -418,7 +432,10 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
         # change the name to have a better name in dask graphs
         _fit.__name__ = f"{_frmt(self)}.fit"
-        self._dask_state = delayed(_fit)(X, y,)
+        self._dask_state = delayed(_fit)(
+            X,
+            y,
+        )
         if self.fit_tag is not None:
             self.resource_tags[self._dask_state] = self.fit_tag
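Renaming the local _fit closure matters because dask, by default, derives the task-key prefix (and thus the label shown in graph visualisations and the dashboard) from the callable's __name__. A rough sketch of that effect, with a made-up label; the key suffix is a hash and will differ from run to run:

import dask

def _fit(X, y):
    # stand-in for the wrapper's local _fit closure
    return (X, y)

# giving the closure a descriptive name changes the key prefix in the dask graph
_fit.__name__ = "DaskWrapper|PCA.fit"   # hypothetical formatted label

state = dask.delayed(_fit)("X", "y")
print(state.key)   # e.g. "DaskWrapper|PCA.fit-3f2a..." (prefix taken from __name__)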
@@ -511,7 +528,9 @@ def wrap(bases, estimator=None, **kwargs):
             # when checkpointing a pipeline, checkpoint each transformer in its own folder
             new_kwargs = dict(kwargs)
-            features_dir, model_path = kwargs.get("features_dir"), kwargs.get("model_path")
+            features_dir, model_path = kwargs.get("features_dir"), kwargs.get(
+                "model_path"
+            )
             if features_dir is not None:
                 new_kwargs["features_dir"] = os.path.join(features_dir, name)
             if model_path is not None:
...
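The last hunk only re-wraps a long line, but the surrounding logic is what makes pipeline checkpointing work: for each named step, wrap() copies the kwargs and points features_dir (and, presumably in the elided lines, model_path) at a per-step subdirectory so transformers do not overwrite each other's files. A small sketch of that path handling, with made-up step names and paths:

import os

kwargs = {"features_dir": "/tmp/checkpoints/features", "model_path": "/tmp/checkpoints/model"}

for name in ["preprocessor", "extractor"]:   # hypothetical pipeline step names
    new_kwargs = dict(kwargs)
    features_dir, model_path = kwargs.get("features_dir"), kwargs.get("model_path")
    if features_dir is not None:
        new_kwargs["features_dir"] = os.path.join(features_dir, name)
    if model_path is not None:
        # assumed to mirror the features_dir handling; the diff is truncated here
        new_kwargs["model_path"] = os.path.join(model_path, name)
    print(name, new_kwargs["features_dir"], new_kwargs["model_path"])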