Skip to content
Snippets Groups Projects
Commit b652eb14 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[wrappers] improve error handling

parent 0e43d548
No related branches found
No related tags found
1 merge request!52[CheckpointWrapper] Allow custom save and load functions through estimator tags
...@@ -138,7 +138,8 @@ class SampleWrapper(BaseWrapper, TransformerMixin): ...@@ -138,7 +138,8 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
if isinstance(samples[0], SampleSet): if isinstance(samples[0], SampleSet):
return [ return [
SampleSet( SampleSet(
self._samples_transform(sset.samples, method_name), parent=sset, self._samples_transform(sset.samples, method_name),
parent=sset,
) )
for sset in samples for sset in samples
] ]
...@@ -407,7 +408,11 @@ class DaskWrapper(BaseWrapper, TransformerMixin): ...@@ -407,7 +408,11 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
""" """
def __init__( def __init__(
self, estimator, fit_tag=None, transform_tag=None, **kwargs, self,
estimator,
fit_tag=None,
transform_tag=None,
**kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
self.estimator = estimator self.estimator = estimator
...@@ -458,13 +463,22 @@ class DaskWrapper(BaseWrapper, TransformerMixin): ...@@ -458,13 +463,22 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
logger.debug(f"{_frmt(self)}.fit") logger.debug(f"{_frmt(self)}.fit")
def _fit(X, y, **fit_params): def _fit(X, y, **fit_params):
self.estimator = self.estimator.fit(X, y, **fit_params) try:
self.estimator = self.estimator.fit(X, y, **fit_params)
except Exception as e:
raise RuntimeError(
f"Something went wrong when fitting {self.estimator} "
f"from {self}"
) from e
copy_learned_attributes(self.estimator, self) copy_learned_attributes(self.estimator, self)
return self.estimator return self.estimator
# change the name to have a better name in dask graphs # change the name to have a better name in dask graphs
_fit.__name__ = f"{_frmt(self)}.fit" _fit.__name__ = f"{_frmt(self)}.fit"
self._dask_state = delayed(_fit)(X, y,) self._dask_state = delayed(_fit)(
X,
y,
)
if self.fit_tag is not None: if self.fit_tag is not None:
self.resource_tags[self._dask_state] = self._make_dask_resource_tag( self.resource_tags[self._dask_state] = self._make_dask_resource_tag(
...@@ -582,7 +596,7 @@ def wrap(bases, estimator=None, **kwargs): ...@@ -582,7 +596,7 @@ def wrap(bases, estimator=None, **kwargs):
estimator, leftover = _wrap(estimator, **kwargs) estimator, leftover = _wrap(estimator, **kwargs)
if leftover: if leftover:
raise ValueError(f"Got extra kwargs that were not consumed: {kwargs}") raise ValueError(f"Got extra kwargs that were not consumed: {leftover}")
return estimator return estimator
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment