Skip to content
Snippets Groups Projects

Add DatasetPipeline to work on xarray datasets

Merged Amir MOHAMMADI requested to merge xarray into master
Files
22
@@ -49,16 +49,10 @@ class DummyTransformer(TransformerMixin, BaseEstimator):
@@ -49,16 +49,10 @@ class DummyTransformer(TransformerMixin, BaseEstimator):
https://github.com/scikit-learn-contrib/project-
https://github.com/scikit-learn-contrib/project-
template/blob/master/skltemplate/_template.py."""
template/blob/master/skltemplate/_template.py."""
def __init__(self, picklable=True, i=None, **kwargs):
def __init__(self, i=None, **kwargs):
super().__init__(**kwargs)
super().__init__(**kwargs)
self.picklable = picklable
self.i = i
self.i = i
if not picklable:
import bob.core
self.rng = bob.core.random.mt19937()
def fit(self, X, y=None):
def fit(self, X, y=None):
return self
return self
@@ -221,7 +215,7 @@ def _build_estimator(path, i):
@@ -221,7 +215,7 @@ def _build_estimator(path, i):
return transformer
return transformer
def _build_transformer(path, i, picklable=True):
def _build_transformer(path, i):
features_dir = os.path.join(path, f"transformer{i}")
features_dir = os.path.join(path, f"transformer{i}")
estimator = mario.wrap(
estimator = mario.wrap(
@@ -288,7 +282,9 @@ def test_checkpoint_fit_transform_pipeline():
@@ -288,7 +282,9 @@ def test_checkpoint_fit_transform_pipeline():
transformer = ("1", _build_transformer(d, 1))
transformer = ("1", _build_transformer(d, 1))
pipeline = Pipeline([fitter, transformer])
pipeline = Pipeline([fitter, transformer])
if dask_enabled:
if dask_enabled:
pipeline = mario.wrap(["dask"], pipeline, fit_tag=[(1, "GPU")], npartitions=1)
pipeline = mario.wrap(
 
["dask"], pipeline, fit_tag=[(1, "GPU")], npartitions=1
 
)
pipeline = pipeline.fit(samples)
pipeline = pipeline.fit(samples)
tags = mario.dask_tags(pipeline)
tags = mario.dask_tags(pipeline)
@@ -331,7 +327,7 @@ def test_checkpoint_fit_transform_pipeline_with_dask_non_pickle():
@@ -331,7 +327,7 @@ def test_checkpoint_fit_transform_pipeline_with_dask_non_pickle():
fitter = ("0", _build_estimator(d, 0))
fitter = ("0", _build_estimator(d, 0))
transformer = (
transformer = (
"1",
"1",
_build_transformer(d, 1, picklable=False),
_build_transformer(d, 1),
)
)
pipeline = Pipeline([fitter, transformer])
pipeline = Pipeline([fitter, transformer])
Loading