Skip to content
Snippets Groups Projects
Commit 22f98128 authored by Yannick DAYER's avatar Yannick DAYER
Browse files

Update the innermost estimator however low it is.

parent 2933d425
Branches
Tags
1 merge request!85Prevent a reference invalidation when wrapped with sample and checkpoint.
Pipeline #59856 passed
......@@ -571,20 +571,19 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
if is_estimator_stateless(self.estimator):
return self
with open(self.model_path, "rb") as f:
estimator = cloudpickle.load(f)
# we don't do self.estimator = estimator, because self.estimator
# might be used elsewhere
if isinstance(self.estimator, SampleWrapper):
# Fix for the case where we are checkpointing a sample-wrapped estimator
self.estimator.estimator.__dict__.update(
estimator.estimator.__dict__
)
# Update other keys in __dict__
for k, v in estimator.__dict__.items():
loaded_estimator = cloudpickle.load(f)
estimator = self.estimator
# For the update, ensure that we have the estimator, not a wrapper
while hasattr(estimator, "estimator"):
# Update this estimator __dict__ except for the attribute "estimator"
for k, v in loaded_estimator.__dict__.items():
if k != "estimator":
setattr(self.estimator, k, v)
else:
self.estimator.__dict__.update(estimator.__dict__)
estimator.__dict__[k] = v
estimator = estimator.estimator
loaded_estimator = loaded_estimator.estimator
# we don't do self.estimator = loaded_estimator, because self.estimator
# might be used elsewhere
estimator.__dict__.update(loaded_estimator.__dict__)
return self
def save_model(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment