Skip to content
Snippets Groups Projects

Prevent a reference invalidation when wrapped with sample and checkpoint.

Merged Yannick DAYER requested to merge fix-chkpt-wrap into master
1 unresolved thread
Files
2
+ 13
4
@@ -571,10 +571,19 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
if is_estimator_stateless(self.estimator):
return self
with open(self.model_path, "rb") as f:
estimator = cloudpickle.load(f)
# we don't do self.estimator = estimator, because self.estimator
loaded_estimator = cloudpickle.load(f)
estimator = self.estimator
# For the update, ensure that we have the estimator, not a wrapper
while hasattr(estimator, "estimator"):
# Update this estimator __dict__ except for the attribute "estimator"
for k, v in loaded_estimator.__dict__.items():
if k != "estimator":
estimator.__dict__[k] = v
estimator = estimator.estimator
loaded_estimator = loaded_estimator.estimator
# we don't do self.estimator = loaded_estimator, because self.estimator
# might be used elsewhere
self.estimator.__dict__.update(estimator.__dict__)
estimator.__dict__.update(loaded_estimator.__dict__)
return self
def save_model(self):
@@ -765,7 +774,7 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
return self
else:
logger.info(
"Ignoring conversion to dask array (checkpoint detected)"
f"Ignoring conversion to dask array (checkpoint detected at {model_path})"
)
def _fit(X, y, **fit_params):
Loading