Commit ac8afd2a authored by Amir MOHAMMADI

[xarray] Allow for multi argument transformers

parent 049b6410
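
The change lets the xarray pipeline helpers pass every dask block through to the wrapped estimator instead of assuming a single data argument: _TokenStableTransform now keeps all positional blocks, strips the trailing sample-key block only when input_has_keys=True, and calls the estimator method with the remaining arguments unpacked. Below is a minimal, self-contained sketch of that calling convention; the transformer and its data are hypothetical stand-ins, and only the argument routing mirrors the new __call__.

import numpy as np


class CropAndShift:
    """Hypothetical transformer whose transform takes two positional inputs."""

    def transform(self, images, offsets):
        # Combine both per-sample inputs (purely illustrative).
        return np.stack([img + off for img, off in zip(images, offsets)])


# Mimic the new routing in _TokenStableTransform.__call__: the dask blocks
# arrive as *args, with the sample keys appended last when input_has_keys=True.
estimator = CropAndShift()
images = np.ones((4, 8))
offsets = np.zeros((4, 8))
keys = np.array([f"sample-{i}" for i in range(4)])

args = (images, offsets, keys)
input_args = args[:-1]  # input_has_keys=True -> drop the trailing key block
features = estimator.transform(*input_args)
assert features.shape == (4, 8)
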
@@ -196,24 +196,31 @@ def _fit(*args, block):
 
 
 class _TokenStableTransform:
-    def __init__(self, block, method_name=None, **kwargs):
+    def __init__(self, block, method_name=None, input_has_keys=False, **kwargs):
         super().__init__(**kwargs)
         self.block = block
         self.method_name = method_name or "transform"
+        self.input_has_keys = input_has_keys
 
     def __dask_tokenize__(self):
         return (self.method_name, self.block.features_dir)
 
     def __call__(self, *args, estimator):
-        data = args[0]
         block, method_name = self.block, self.method_name
         logger.info(f"Calling {block.estimator_name}.{method_name}")
 
-        features = getattr(estimator, self.method_name)(data)
+        input_args = args[:-1] if self.input_has_keys else args
+        try:
+            features = getattr(estimator, self.method_name)(*input_args)
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to transform data: {estimator}.{self.method_name}(*{input_args})"
+            ) from e
 
         # if keys are provided, checkpoint features
-        if len(args) == 2:
-            key = args[1]
+        if self.input_has_keys:
+            data = args[0]
+            key = args[-1]
 
             l1, l2 = len(data), len(features)
             if l1 != l2:
@@ -300,7 +307,7 @@ def _blockwise_with_block_args(args, block, method_name=None):
     return output_dim_name, new_axes, input_arg_pairs, dims, meta, output_shape
 
 
-def _blockwise_with_block(args, block, method_name=None):
+def _blockwise_with_block(args, block, method_name=None, input_has_keys=False):
     (
         output_dim_name,
         new_axes,
@@ -309,7 +316,9 @@ def _blockwise_with_block(args, block, method_name=None):
         meta,
         _,
     ) = _blockwise_with_block_args(args, block, method_name=None)
-    transform_func = _TokenStableTransform(block, method_name)
+    transform_func = _TokenStableTransform(
+        block, method_name, input_has_keys=input_has_keys
+    )
     transform_func.__name__ = f"{block.estimator_name}.{method_name}"
 
     data = dask.array.blockwise(
@@ -356,7 +365,9 @@ def _transform_or_load(block, ds, input_columns, mn):
     # compute non-saved data
     if total_samples_n - saved_samples_n > 0:
         args = _get_dask_args_from_ds(nonsaved_ds, input_columns)
-        dims, computed_data = _blockwise_with_block(args, block, mn)
+        dims, computed_data = _blockwise_with_block(
+            args, block, mn, input_has_keys=True
+        )
 
     # load saved data
     if saved_samples_n > 0:
@@ -367,7 +378,10 @@ def _transform_or_load(block, ds, input_columns, mn):
         dims, meta, shape = _blockwise_with_block_args(args, block, mn)[-3:]
         loaded_data = [
             dask.array.from_delayed(
-                dask.delayed(block.load)(k), shape=shape[1:], meta=meta, name=False,
+                dask.delayed(block.load)(k),
+                shape=shape[1:],
+                meta=meta,
+                name=False,
             )[None, ...]
             for k in key[saved_samples]
         ]
@@ -414,7 +428,10 @@ class DatasetPipeline(_BaseComposition):
     def _transform(self, ds, do_fit=False, method_name=None):
         for i, block in enumerate(self.graph):
             if block.dataset_map is not None:
-                ds = block.dataset_map(ds)
+                try:
+                    ds = block.dataset_map(ds)
+                except Exception as e:
+                    raise RuntimeError(f"Could not map ds {ds}\n with {block.dataset_map}") from e
                 continue
 
             if do_fit:
@@ -433,7 +450,10 @@ class DatasetPipeline(_BaseComposition):
                     block.estimator_ = _fit(*args, block=block)
                 else:
                     _fit.__name__ = f"{block.estimator_name}.fit"
-                    block.estimator_ = dask.delayed(_fit)(*args, block=block,)
+                    block.estimator_ = dask.delayed(_fit)(
+                        *args,
+                        block=block,
+                    )
 
             mn = "transform"
             if i == len(self.graph) - 1:
@@ -443,7 +463,9 @@ class DatasetPipeline(_BaseComposition):
 
             if block.features_dir is None:
                 args = _get_dask_args_from_ds(ds, block.transform_input)
-                dims, data = _blockwise_with_block(args, block, mn)
+                dims, data = _blockwise_with_block(
+                    args, block, mn, input_has_keys=False
+                )
             else:
                 dims, data = _transform_or_load(block, ds, block.transform_input, mn)
 
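For reference, a small self-contained illustration of the error wrapping added to __call__. The Broken transformer and the call helper below are hypothetical, but the try/except mirrors the lines added above: a failure inside the estimator now surfaces as a RuntimeError naming the method and the arguments it received, with the original exception chained as the cause.

class Broken:
    """Hypothetical transformer that always fails."""

    def transform(self, data):
        raise ValueError("bad input")


def call(estimator, method_name, *args, input_has_keys=False):
    # Same shape as the new _TokenStableTransform.__call__ error handling.
    input_args = args[:-1] if input_has_keys else args
    try:
        return getattr(estimator, method_name)(*input_args)
    except Exception as e:
        raise RuntimeError(
            f"Failed to transform data: {estimator}.{method_name}(*{input_args})"
        ) from e


try:
    call(Broken(), "transform", [1, 2, 3], ["k0", "k1", "k2"], input_has_keys=True)
except RuntimeError as err:
    print(err)            # names the estimator, the method and the arguments
    print(err.__cause__)  # bad input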