Commit 17a39fe3 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

[CheckpointWrapper] Allow custom save and load functions through estimator tags

parent e2459dc5
......@@ -14,6 +14,9 @@ from sklearn.base import BaseEstimator
from sklearn.pipeline import _name_estimators
from sklearn.utils.metaestimators import _BaseComposition
from bob.io.base import load
from bob.io.base import save
from .sample import SAMPLE_DATA_ATTRS
from .sample import _ReprMixin
from .utils import is_estimator_stateless
......@@ -176,8 +179,12 @@ class Block(_ReprMixin):
self.model_path = model_path
self.features_dir = features_dir
self.extension = extension
self.save_func = save_func or partial(np.save, allow_pickle=False)
self.load_func = load_func or np.load
self.save_func = (
save_func or estimator._get_tags().get("bob_features_save_fn") or save
)
self.load_func = (
load_func or estimator._get_tags().get("bob_features_load_fn") or load
)
self.dataset_map = dataset_map
self.input_dask_array = input_dask_array
self.fit_kwargs = fit_kwargs or {}
......@@ -204,8 +211,8 @@ class Block(_ReprMixin):
def save(self, key, data):
path = self.make_path(key)
os.makedirs(os.path.dirname(path), exist_ok=True)
# this should be save_func(path, data) so it's compatible with np.save
return self.save_func(path, data)
# this should be save_func(data, path) so it's compatible with bob.io.base.save
return self.save_func(data, path)
def load(self, key):
path = self.make_path(key)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment