Commit 0e43d548 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

[CheckpointWrapper] only checkpoint the transform method, automatically add...

[CheckpointWrapper] only checkpoint the transform method, automatically add .pkl extension to model names
parent 99fa97c0
......@@ -305,16 +305,16 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
return self._checkpoint_transform(samples, "transform")
def decision_function(self, samples):
return self._checkpoint_transform(samples, "decision_function")
return self.estimator.decision_function(samples)
def predict(self, samples):
return self._checkpoint_transform(samples, "predict")
return self.estimator.predict(samples)
def predict_proba(self, samples):
return self._checkpoint_transform(samples, "predict_proba")
return self.estimator.predict_proba(samples)
def score(self, samples):
return self._checkpoint_transform(samples, "score")
return self.estimator.score(samples)
def fit(self, samples, y=None):
......@@ -567,7 +567,7 @@ def wrap(bases, estimator=None, **kwargs):
if features_dir is not None:
new_kwargs["features_dir"] = os.path.join(features_dir, name)
if model_path is not None:
new_kwargs["model_path"] = os.path.join(model_path, name)
new_kwargs["model_path"] = os.path.join(model_path, f"{name}.pkl")
trans, leftover = _wrap(trans, **new_kwargs)
estimator.steps[idx] = (name, trans)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment