Commit a4ee858b authored by Tiago de Freitas Pereira

Merge branch 'dask-pipelines' into 'master'

Multiple Changes

See merge request !45
parents 324038de bd7427e3
Pipeline #45521 passed with stages in 11 minutes and 55 seconds
@@ -6,6 +6,7 @@ from .sample import DelayedSample
 from .sample import DelayedSampleSet
 from .sample import Sample
 from .sample import SampleSet
+from .sample import SampleBatch
 from .sample import hdf5_to_sample  # noqa
 from .sample import sample_to_hdf5  # noqa
 from .wrappers import BaseWrapper
@@ -41,6 +42,7 @@ __appropriate__(
     DelayedSample,
     SampleSet,
     DelayedSampleSet,
+    SampleBatch,
     BaseWrapper,
     DelayedSamplesCall,
     SampleWrapper,
...
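Note: the newly exported `SampleBatch` presents a list of samples as an array of their `.data`, so plain scikit-learn estimators can consume a batch without manual stacking. A minimal usage sketch (the sample values are made up, and the `np.asarray` conversion assumes `SampleBatch`'s array view):

```python
import numpy as np
from bob.pipelines import Sample, SampleBatch

# Two samples whose .data are 1D arrays (hypothetical values).
samples = [Sample(np.array([1.0, 2.0])), Sample(np.array([3.0, 4.0]))]

# SampleBatch exposes the samples' .data as one array-like object.
batch = SampleBatch(samples)
X = np.asarray(batch)
print(X.shape)  # (2, 2)
```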
@@ -13,6 +13,7 @@ from sklearn.base import BaseEstimator
 from sklearn.base import MetaEstimatorMixin
 from sklearn.base import TransformerMixin
 from sklearn.pipeline import Pipeline
+from sklearn.preprocessing import FunctionTransformer
 from .sample import DelayedSample
 from .sample import SampleBatch
@@ -31,7 +32,15 @@ def _frmt(estimator, limit=30):
     while hasattr(estimator, "estimator"):
         name += f"{_n(estimator)}|"
         estimator = estimator.estimator
-    name += str(estimator)
+
+    if (
+        isinstance(estimator, FunctionTransformer)
+        and type(estimator) is FunctionTransformer
+    ):
+        name += str(estimator.func.__name__)
+    else:
+        name += str(estimator)
+
     name = f"{name:.{limit}}"
     return name
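Note: this `_frmt` change only affects labels in dask graphs: a bare `FunctionTransformer` is now rendered by the name of the function it wraps instead of its full `repr`. A small illustration (the `normalize` helper is hypothetical):

```python
from sklearn.preprocessing import FunctionTransformer

def normalize(x):
    return x / 255.0

t = FunctionTransformer(normalize)

# Old label: the full repr, e.g. "FunctionTransformer(func=<function ...>)".
# New label: just the wrapped function's name.
print(t.func.__name__)  # "normalize"
```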
@@ -128,7 +137,8 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
         if isinstance(samples[0], SampleSet):
             return [
                 SampleSet(
-                    self._samples_transform(sset.samples, method_name), parent=sset,
+                    self._samples_transform(sset.samples, method_name),
+                    parent=sset,
                 )
                 for sset in samples
             ]
@@ -312,7 +322,10 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
         os.makedirs(os.path.dirname(path), exist_ok=True)
         # Gets sample.data or sample.<sample_attribute> if specified
         to_save = getattr(sample, self.sample_attribute)
-        return self.save_func(to_save, path)
+        try:
+            self.save_func(to_save, path)
+        except Exception as e:
+            raise RuntimeError(f"Could not save {to_save} during {self}.save") from e

     def load(self, sample, path):
         # because we are checkpointing, we return a DelayedSample
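Note: with this change a failing `save_func` no longer surfaces as an opaque low-level error; the wrapper re-raises a `RuntimeError` naming the offending data and wrapper. A minimal sketch of triggering the checkpointing path, assuming the string aliases documented for the project's public `wrap` entry point:

```python
from sklearn.preprocessing import FunctionTransformer
from bob.pipelines import wrap

def double(x):
    return x * 2

transformer = FunctionTransformer(double)

# "sample" makes the estimator consume/produce Sample objects,
# "checkpoint" writes each transformed sample under features_dir;
# a failure inside save_func now raises a descriptive RuntimeError.
wrapped = wrap(
    ["sample", "checkpoint"],
    transformer,
    features_dir="./checkpoints",  # hypothetical output folder
)
```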
@@ -363,7 +376,11 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
     """

     def __init__(
-        self, estimator, fit_tag=None, transform_tag=None, **kwargs,
+        self,
+        estimator,
+        fit_tag=None,
+        transform_tag=None,
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.estimator = estimator
@@ -415,7 +432,10 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
         # change the name to have a better name in dask graphs
         _fit.__name__ = f"{_frmt(self)}.fit"
-        self._dask_state = delayed(_fit)(X, y,)
+        self._dask_state = delayed(_fit)(
+            X,
+            y,
+        )

         if self.fit_tag is not None:
             self.resource_tags[self._dask_state] = self.fit_tag
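Note: `resource_tags` maps delayed nodes to the `fit_tag`/`transform_tag` given at wrap time; with a distributed scheduler such tags can be matched against worker resources. A generic sketch of the underlying dask mechanism (the cluster setup and the "GPU" resource name are assumptions, not part of this change):

```python
import dask
from dask.distributed import Client

client = Client()  # hypothetical local cluster

# Workers started with e.g. `dask-worker ... --resources GPU=1` can be
# targeted by submitting the computation with a matching resource tag.
task = dask.delayed(sum)([1, 2, 3])
future = client.compute(task, resources={"GPU": 1})
print(future.result())  # 6
```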
@@ -505,7 +525,18 @@ def wrap(bases, estimator=None, **kwargs):
     if isinstance(estimator, Pipeline):
         # wrap inner steps
         for idx, name, trans in estimator._iter():
-            trans, leftover = _wrap(trans, **kwargs)
+
+            # when checkpointing a pipeline, checkpoint each transformer in its own folder
+            new_kwargs = dict(kwargs)
+            features_dir, model_path = kwargs.get("features_dir"), kwargs.get(
+                "model_path"
+            )
+            if features_dir is not None:
+                new_kwargs["features_dir"] = os.path.join(features_dir, name)
+            if model_path is not None:
+                new_kwargs["model_path"] = os.path.join(model_path, name)
+
+            trans, leftover = _wrap(trans, **new_kwargs)
             estimator.steps[idx] = (name, trans)

     # if being wrapped with DaskWrapper, add ToDaskBag to the steps
...
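Note: the `wrap` change above means that checkpointing a `Pipeline` now gives every inner step its own subfolder under `features_dir` (and `model_path`), instead of all steps writing to the same place. A sketch with hypothetical step names, helpers, and paths:

```python
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
from bob.pipelines import wrap

def scale(x):
    return x / 255.0

def shift(x):
    return x - 0.5

pipeline = Pipeline([
    ("scale", FunctionTransformer(scale)),
    ("shift", FunctionTransformer(shift)),
])

# Each step is now checkpointed separately:
#   ./ckpt/scale/...  and  ./ckpt/shift/...
pipeline = wrap(["sample", "checkpoint"], pipeline, features_dir="./ckpt")
```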
@@ -196,24 +196,31 @@ def _fit(*args, block):
 class _TokenStableTransform:
-    def __init__(self, block, method_name=None, **kwargs):
+    def __init__(self, block, method_name=None, input_has_keys=False, **kwargs):
         super().__init__(**kwargs)
         self.block = block
         self.method_name = method_name or "transform"
+        self.input_has_keys = input_has_keys

     def __dask_tokenize__(self):
         return (self.method_name, self.block.features_dir)

     def __call__(self, *args, estimator):
-        data = args[0]
         block, method_name = self.block, self.method_name
         logger.info(f"Calling {block.estimator_name}.{method_name}")

-        features = getattr(estimator, self.method_name)(data)
+        input_args = args[:-1] if self.input_has_keys else args
+        try:
+            features = getattr(estimator, self.method_name)(*input_args)
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to transform data: {estimator}.{self.method_name}(*{input_args})"
+            ) from e

         # if keys are provided, checkpoint features
-        if len(args) == 2:
-            key = args[1]
+        if self.input_has_keys:
+            data = args[0]
+            key = args[-1]

             l1, l2 = len(data), len(features)
             if l1 != l2:
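Note: the new `input_has_keys` flag makes the argument handling explicit: when set, the last positional argument carries the checkpoint keys and must be stripped before calling the estimator; otherwise all arguments are forwarded unchanged. A standalone sketch of that split (not the library code):

```python
def split_args(args, input_has_keys):
    # With keys: forward everything but the last argument to the
    # estimator, and keep the last one for checkpointing.
    input_args = args[:-1] if input_has_keys else args
    keys = args[-1] if input_has_keys else None
    return input_args, keys

X_block = [[1.0, 2.0], [3.0, 4.0]]    # one dask block of data
key_block = ["sample-0", "sample-1"]  # matching checkpoint keys

input_args, keys = split_args((X_block, key_block), input_has_keys=True)
assert input_args == (X_block,) and keys == key_block
```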
@@ -300,7 +307,7 @@ def _blockwise_with_block_args(args, block, method_name=None):
     return output_dim_name, new_axes, input_arg_pairs, dims, meta, output_shape


-def _blockwise_with_block(args, block, method_name=None):
+def _blockwise_with_block(args, block, method_name=None, input_has_keys=False):
     (
         output_dim_name,
         new_axes,
@@ -309,7 +316,9 @@ def _blockwise_with_block(args, block, method_name=None):
         meta,
         _,
     ) = _blockwise_with_block_args(args, block, method_name=None)
-    transform_func = _TokenStableTransform(block, method_name)
+    transform_func = _TokenStableTransform(
+        block, method_name, input_has_keys=input_has_keys
+    )
     transform_func.__name__ = f"{block.estimator_name}.{method_name}"
     data = dask.array.blockwise(
@@ -356,7 +365,9 @@ def _transform_or_load(block, ds, input_columns, mn):
     # compute non-saved data
     if total_samples_n - saved_samples_n > 0:
         args = _get_dask_args_from_ds(nonsaved_ds, input_columns)
-        dims, computed_data = _blockwise_with_block(args, block, mn)
+        dims, computed_data = _blockwise_with_block(
+            args, block, mn, input_has_keys=True
+        )

     # load saved data
     if saved_samples_n > 0:
@@ -367,7 +378,10 @@ def _transform_or_load(block, ds, input_columns, mn):
         dims, meta, shape = _blockwise_with_block_args(args, block, mn)[-3:]
         loaded_data = [
             dask.array.from_delayed(
-                dask.delayed(block.load)(k), shape=shape[1:], meta=meta, name=False,
+                dask.delayed(block.load)(k),
+                shape=shape[1:],
+                meta=meta,
+                name=False,
             )[None, ...]
             for k in key[saved_samples]
         ]
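Note: the reflowed call above builds one lazy block per checkpointed key and stacks them into a dask array. A self-contained sketch of the same `from_delayed` pattern (the `load` stub stands in for `block.load`):

```python
import dask
import dask.array as da
import numpy as np

def load(key):
    # Stand-in for block.load: read one checkpointed feature vector.
    return np.ones(4, dtype="float64")

keys = ["k0", "k1", "k2"]
blocks = [
    da.from_delayed(
        dask.delayed(load)(k),
        shape=(4,),
        meta=np.array((), dtype="float64"),
        name=False,
    )[None, ...]  # promote each vector to a 1 x 4 block
    for k in keys
]
loaded = da.concatenate(blocks, axis=0)
print(loaded.shape)  # (3, 4)
```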
@@ -414,7 +428,10 @@ class DatasetPipeline(_BaseComposition):
     def _transform(self, ds, do_fit=False, method_name=None):
         for i, block in enumerate(self.graph):

             if block.dataset_map is not None:
-                ds = block.dataset_map(ds)
+                try:
+                    ds = block.dataset_map(ds)
+                except Exception as e:
+                    raise RuntimeError(f"Could not map ds {ds}\n with {block.dataset_map}") from e
                 continue

             if do_fit:
@@ -433,7 +450,10 @@ class DatasetPipeline(_BaseComposition):
                     block.estimator_ = _fit(*args, block=block)
                 else:
                     _fit.__name__ = f"{block.estimator_name}.fit"
-                    block.estimator_ = dask.delayed(_fit)(*args, block=block,)
+                    block.estimator_ = dask.delayed(_fit)(
+                        *args,
+                        block=block,
+                    )

             mn = "transform"
             if i == len(self.graph) - 1:
@@ -443,7 +463,9 @@ class DatasetPipeline(_BaseComposition):
             if block.features_dir is None:
                 args = _get_dask_args_from_ds(ds, block.transform_input)
-                dims, data = _blockwise_with_block(args, block, mn)
+                dims, data = _blockwise_with_block(
+                    args, block, mn, input_has_keys=False
+                )
             else:
                 dims, data = _transform_or_load(block, ds, block.transform_input, mn)
...
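Note: the `DatasetPipeline` changes defer fitting as a named dask task, so the graph stays readable and fitting only runs on `compute()`. A generic sketch of that pattern (the estimator and data are placeholders, not the pipeline's internals):

```python
import dask
import numpy as np
from sklearn.preprocessing import StandardScaler

def _fit(X, estimator):
    return estimator.fit(X)

X = np.random.randn(100, 4)

# Naming the function first gives the task a readable label in the graph,
# mirroring the `_fit.__name__ = ...` assignment in the diff above.
_fit.__name__ = "StandardScaler.fit"
fit_task = dask.delayed(_fit)(X, estimator=StandardScaler())
scaler = fit_task.compute()
print(scaler.mean_.shape)  # (4,)
```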