Commit b652eb14 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

[wrappers] improve error handling

parent 0e43d548
......@@ -138,7 +138,8 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
if isinstance(samples[0], SampleSet):
return [
SampleSet(
self._samples_transform(sset.samples, method_name), parent=sset,
self._samples_transform(sset.samples, method_name),
parent=sset,
)
for sset in samples
]
......@@ -407,7 +408,11 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
"""
def __init__(
self, estimator, fit_tag=None, transform_tag=None, **kwargs,
self,
estimator,
fit_tag=None,
transform_tag=None,
**kwargs,
):
super().__init__(**kwargs)
self.estimator = estimator
......@@ -458,13 +463,22 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
logger.debug(f"{_frmt(self)}.fit")
def _fit(X, y, **fit_params):
self.estimator = self.estimator.fit(X, y, **fit_params)
try:
self.estimator = self.estimator.fit(X, y, **fit_params)
except Exception as e:
raise RuntimeError(
f"Something went wrong when fitting {self.estimator} "
f"from {self}"
) from e
copy_learned_attributes(self.estimator, self)
return self.estimator
# change the name to have a better name in dask graphs
_fit.__name__ = f"{_frmt(self)}.fit"
self._dask_state = delayed(_fit)(X, y,)
self._dask_state = delayed(_fit)(
X,
y,
)
if self.fit_tag is not None:
self.resource_tags[self._dask_state] = self._make_dask_resource_tag(
......@@ -582,7 +596,7 @@ def wrap(bases, estimator=None, **kwargs):
estimator, leftover = _wrap(estimator, **kwargs)
if leftover:
raise ValueError(f"Got extra kwargs that were not consumed: {kwargs}")
raise ValueError(f"Got extra kwargs that were not consumed: {leftover}")
return estimator
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment