Commit 83448a4a authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

[xarray] fix tests

parent 29a970c8
Pipeline #46439 passed with stage
in 8 minutes and 2 seconds
......@@ -16,6 +16,7 @@ from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import SGDClassifier
from sklearn.preprocessing import StandardScaler
import bob.io.base
import bob.pipelines as mario
......@@ -178,8 +179,8 @@ def test_dataset_pipeline_with_checkpoints():
)
):
path = os.path.join(pca_features, path)
assert path.endswith(f"{i}.npy"), path
np.testing.assert_array_equal(np.load(path).shape, (3,))
assert path.endswith(f"{i}.hdf5"), path
np.testing.assert_array_equal(bob.io.base.load(path).shape, (3,))
# now this time it should load features
# delete one of the features
......
......@@ -154,7 +154,7 @@ class Block(_ReprMixin):
estimator_name=None,
model_path=None,
features_dir=None,
extension=".npy",
extension=".hdf5",
save_func=None,
load_func=None,
dataset_map=None,
......@@ -179,12 +179,18 @@ class Block(_ReprMixin):
self.model_path = model_path
self.features_dir = features_dir
self.extension = extension
self.save_func = (
save_func or estimator._get_tags().get("bob_features_save_fn") or save
estimator_save_fn = (
None
if estimator is None
else estimator._get_tags().get("bob_features_save_fn")
)
self.load_func = (
load_func or estimator._get_tags().get("bob_features_load_fn") or load
estimator_load_fn = (
None
if estimator is None
else estimator._get_tags().get("bob_features_load_fn")
)
self.save_func = save_func or estimator_save_fn or save
self.load_func = load_func or estimator_load_fn or load
self.dataset_map = dataset_map
self.input_dask_array = input_dask_array
self.fit_kwargs = fit_kwargs or {}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment