Skip to content
Snippets Groups Projects

breaking: checkpoint the inner estimator only

Merged Amir MOHAMMADI requested to merge checkpoint-wrapper into master
1 file
+ 5
3
Compare changes
  • Side-by-side
  • Inline
+ 5
3
@@ -570,8 +570,10 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
if is_estimator_stateless(self.estimator):
return self
with open(self.model_path, "rb") as f:
model = cloudpickle.load(f)
self.__dict__.update(model.__dict__)
estimator = cloudpickle.load(f)
# we don't do self.estimator = estimator, because self.estimator
# might be used elsewhere
self.estimator.__dict__.update(estimator.__dict__)
return self
def save_model(self):
@@ -579,7 +581,7 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
return self
os.makedirs(os.path.dirname(self.model_path), exist_ok=True)
with open(self.model_path, "wb") as f:
cloudpickle.dump(self, f)
cloudpickle.dump(self.estimator, f)
return self
Loading