[CheckpointWrapper] Allow custom save and load functions through estimator tags

Merged: Amir MOHAMMADI requested to merge estimator-tags into master
1 unresolved thread
1 file changed: +34 −12
@@ -138,7 +138,8 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
         if isinstance(samples[0], SampleSet):
             return [
                 SampleSet(
-                    self._samples_transform(sset.samples, method_name), parent=sset,
+                    self._samples_transform(sset.samples, method_name),
+                    parent=sset,
                 )
                 for sset in samples
             ]
@@ -248,8 +249,16 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
         self.model_path = model_path
         self.features_dir = features_dir
         self.extension = extension
-        self.save_func = save_func or bob.io.base.save
-        self.load_func = load_func or bob.io.base.load
+        self.save_func = (
+            save_func
+            or estimator._get_tags().get("bob_features_save_fn")
+            or bob.io.base.save
+        )
+        self.load_func = (
+            load_func
+            or estimator._get_tags().get("bob_features_load_fn")
+            or bob.io.base.load
+        )
         self.sample_attribute = sample_attribute
         self.hash_fn = hash_fn
         if model_path is None and features_dir is None:
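
The hunk above is the core of this MR: when no save_func/load_func is passed explicitly, CheckpointWrapper now falls back to the scikit-learn estimator tags "bob_features_save_fn" and "bob_features_load_fn" before defaulting to bob.io.base.save/load. A minimal sketch of how an estimator could advertise its own functions; the transformer, the save/load helpers, and the import path are illustrative and not part of this MR:

    import numpy as np
    from sklearn.base import BaseEstimator, TransformerMixin
    from bob.pipelines.wrappers import CheckpointWrapper  # import path assumed

    def save_npy(data, path):
        # assumed to follow the (data, path) convention of bob.io.base.save
        np.save(path, data)

    def load_npy(path):
        return np.load(path)

    class MyTransformer(TransformerMixin, BaseEstimator):
        def fit(self, X, y=None):
            return self

        def transform(self, X):
            return np.asarray(X)

        def _more_tags(self):
            # merged into the result of _get_tags(), which CheckpointWrapper queries
            return {
                "bob_features_save_fn": save_npy,
                "bob_features_load_fn": load_npy,
            }

    wrapped = CheckpointWrapper(MyTransformer(), features_dir="/tmp/features")
    assert wrapped.save_func is save_npy  # the tag wins over the bob.io.base default
    assert wrapped.load_func is load_npy
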
@@ -305,16 +314,16 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
         return self._checkpoint_transform(samples, "transform")

     def decision_function(self, samples):
-        return self._checkpoint_transform(samples, "decision_function")
+        return self.estimator.decision_function(samples)

     def predict(self, samples):
-        return self._checkpoint_transform(samples, "predict")
+        return self.estimator.predict(samples)

     def predict_proba(self, samples):
-        return self._checkpoint_transform(samples, "predict_proba")
+        return self.estimator.predict_proba(samples)

     def score(self, samples):
-        return self._checkpoint_transform(samples, "score")
+        return self.estimator.score(samples)

     def fit(self, samples, y=None):
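
With this hunk, only transform still goes through the checkpointing machinery; the prediction-style methods now delegate straight to the wrapped estimator and write nothing to disk. Roughly, with hypothetical estimator and samples objects:

    wrapped = CheckpointWrapper(estimator, features_dir="/tmp/features")
    features = wrapped.transform(samples)  # checkpointed: files are written under /tmp/features
    labels = wrapped.predict(samples)      # not checkpointed: plain estimator.predict
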
@@ -407,7 +416,11 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
     """

     def __init__(
-        self, estimator, fit_tag=None, transform_tag=None, **kwargs,
+        self,
+        estimator,
+        fit_tag=None,
+        transform_tag=None,
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.estimator = estimator
@@ -458,13 +471,22 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
         logger.debug(f"{_frmt(self)}.fit")

         def _fit(X, y, **fit_params):
-            self.estimator = self.estimator.fit(X, y, **fit_params)
+            try:
+                self.estimator = self.estimator.fit(X, y, **fit_params)
+            except Exception as e:
+                raise RuntimeError(
+                    f"Something went wrong when fitting {self.estimator} "
+                    f"from {self}"
+                ) from e
             copy_learned_attributes(self.estimator, self)
             return self.estimator

         # change the name to have a better name in dask graphs
         _fit.__name__ = f"{_frmt(self)}.fit"
-        self._dask_state = delayed(_fit)(X, y,)
+        self._dask_state = delayed(_fit)(
+            X,
+            y,
+        )

         if self.fit_tag is not None:
             self.resource_tags[self._dask_state] = self._make_dask_resource_tag(
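
The try/except does not change when the error happens, only how it reads: fitting is still deferred inside a dask Delayed, so a failure only surfaces once the graph is computed, and it now arrives as a RuntimeError naming the offending estimator and wrapper. A rough illustration, where the estimator, X, and y are placeholders:

    wrapped = DaskWrapper(SomeEstimator())
    wrapped.fit(X, y)              # only builds the dask graph; nothing runs yet
    wrapped._dask_state.compute()  # executes _fit; a failing estimator.fit is re-raised as
                                   # RuntimeError("Something went wrong when fitting ... from ...")
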
@@ -567,7 +589,7 @@ def wrap(bases, estimator=None, **kwargs):
         if features_dir is not None:
             new_kwargs["features_dir"] = os.path.join(features_dir, name)
         if model_path is not None:
-            new_kwargs["model_path"] = os.path.join(model_path, name)
+            new_kwargs["model_path"] = os.path.join(model_path, f"{name}.pkl")
         trans, leftover = _wrap(trans, **new_kwargs)
         estimator.steps[idx] = (name, trans)
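
Before this hunk, the per-step model checkpoint of a wrapped Pipeline was written to a bare path named after the step; it now gets an explicit .pkl extension. A sketch of the effect; step names, paths, and the wrap bases are illustrative, and the wrap import path is assumed:

    from sklearn.decomposition import PCA
    from sklearn.pipeline import Pipeline
    from sklearn.svm import SVC
    from bob.pipelines import wrap  # import path assumed

    pipeline = Pipeline([("pca", PCA()), ("classifier", SVC())])
    pipeline = wrap(
        ["sample", "checkpoint"],
        pipeline,
        features_dir="/tmp/checkpoints",
        model_path="/tmp/models",
    )
    # per-step model checkpoints now get an explicit extension:
    #   /tmp/models/pca.pkl and /tmp/models/classifier.pkl
    # (previously the bare paths /tmp/models/pca and /tmp/models/classifier)
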
@@ -582,7 +604,7 @@ def wrap(bases, estimator=None, **kwargs):
     estimator, leftover = _wrap(estimator, **kwargs)

     if leftover:
-        raise ValueError(f"Got extra kwargs that were not consumed: {kwargs}")
+        raise ValueError(f"Got extra kwargs that were not consumed: {leftover}")

     return estimator