Skip to content
Snippets Groups Projects
Commit 0e43d548 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[CheckpointWrapper] only checkpoint the transform method, automatically add...

[CheckpointWrapper] only checkpoint the transform method, automatically add .pkl extension to model names
parent 99fa97c0
No related branches found
No related tags found
1 merge request!52[CheckpointWrapper] Allow custom save and load functions through estimator tags
...@@ -305,16 +305,16 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin): ...@@ -305,16 +305,16 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
return self._checkpoint_transform(samples, "transform") return self._checkpoint_transform(samples, "transform")
def decision_function(self, samples): def decision_function(self, samples):
return self._checkpoint_transform(samples, "decision_function") return self.estimator.decision_function(samples)
def predict(self, samples): def predict(self, samples):
return self._checkpoint_transform(samples, "predict") return self.estimator.predict(samples)
def predict_proba(self, samples): def predict_proba(self, samples):
return self._checkpoint_transform(samples, "predict_proba") return self.estimator.predict_proba(samples)
def score(self, samples): def score(self, samples):
return self._checkpoint_transform(samples, "score") return self.estimator.score(samples)
def fit(self, samples, y=None): def fit(self, samples, y=None):
...@@ -567,7 +567,7 @@ def wrap(bases, estimator=None, **kwargs): ...@@ -567,7 +567,7 @@ def wrap(bases, estimator=None, **kwargs):
if features_dir is not None: if features_dir is not None:
new_kwargs["features_dir"] = os.path.join(features_dir, name) new_kwargs["features_dir"] = os.path.join(features_dir, name)
if model_path is not None: if model_path is not None:
new_kwargs["model_path"] = os.path.join(model_path, name) new_kwargs["model_path"] = os.path.join(model_path, f"{name}.pkl")
trans, leftover = _wrap(trans, **new_kwargs) trans, leftover = _wrap(trans, **new_kwargs)
estimator.steps[idx] = (name, trans) estimator.steps[idx] = (name, trans)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment