Commit 60cf9996 authored by Tiago de Freitas Pereira

Merge branch 'dask_client_names' into 'master'

Dask client names

See merge request !59
parents e2459dc5 83448a4a
Pipeline #46529 passed with stages in 4 minutes and 59 seconds
@@ -15,7 +15,20 @@ __path__ = extend_path(__path__, __name__)
# cls=ResourceOption,
# )
VALID_DASK_CLIENT_STRINGS = ("single-threaded", "sync", "threaded", "processes")
try:
    import dask
    VALID_DASK_CLIENT_STRINGS = dask.base.named_schedulers
except (ModuleNotFoundError, ImportError):
    VALID_DASK_CLIENT_STRINGS = (
        "sync",
        "synchronous",
        "single-threaded",
        "threads",
        "threading",
        "processes",
        "multiprocessing",
    )
def dask_get_partition_size(cluster, n_objects, lower_bound=200):
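For context, a minimal sketch (not part of this merge request) of how a client string from VALID_DASK_CLIENT_STRINGS can be validated and handed to dask as a named scheduler; the compute_with_named_scheduler helper below is hypothetical.

import dask

# dask.base.named_schedulers is a dict keyed by the accepted scheduler
# names ("sync", "synchronous", "single-threaded", "threads", ...), so
# membership tests work the same as with the fallback tuple above.
VALID_DASK_CLIENT_STRINGS = tuple(dask.base.named_schedulers)


def compute_with_named_scheduler(delayed_object, client_name="single-threaded"):
    # Hypothetical helper: validate the string, then pass it to dask.
    if client_name not in VALID_DASK_CLIENT_STRINGS:
        raise ValueError(
            f"Unknown dask scheduler {client_name!r}; "
            f"expected one of {VALID_DASK_CLIENT_STRINGS}"
        )
    return delayed_object.compute(scheduler=client_name)


total = dask.delayed(sum)([1, 2, 3])
assert compute_with_named_scheduler(total, "threads") == 6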
@@ -16,6 +16,7 @@ from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import SGDClassifier
from sklearn.preprocessing import StandardScaler
import bob.io.base
import bob.pipelines as mario
@@ -178,8 +179,8 @@ def test_dataset_pipeline_with_checkpoints():
        )
    ):
        path = os.path.join(pca_features, path)
        assert path.endswith(f"{i}.npy"), path
        np.testing.assert_array_equal(np.load(path).shape, (3,))
        assert path.endswith(f"{i}.hdf5"), path
        np.testing.assert_array_equal(bob.io.base.load(path).shape, (3,))

    # now this time it should load features
    # delete one of the features
@@ -14,6 +14,9 @@ from sklearn.base import BaseEstimator
from sklearn.pipeline import _name_estimators
from sklearn.utils.metaestimators import _BaseComposition
from bob.io.base import load
from bob.io.base import save
from .sample import SAMPLE_DATA_ATTRS
from .sample import _ReprMixin
from .utils import is_estimator_stateless
@@ -151,7 +154,7 @@ class Block(_ReprMixin):
        estimator_name=None,
        model_path=None,
        features_dir=None,
        extension=".npy",
        extension=".hdf5",
        save_func=None,
        load_func=None,
        dataset_map=None,
@@ -176,8 +179,18 @@
        self.model_path = model_path
        self.features_dir = features_dir
        self.extension = extension
        self.save_func = save_func or partial(np.save, allow_pickle=False)
        self.load_func = load_func or np.load
        estimator_save_fn = (
            None
            if estimator is None
            else estimator._get_tags().get("bob_features_save_fn")
        )
        estimator_load_fn = (
            None
            if estimator is None
            else estimator._get_tags().get("bob_features_load_fn")
        )
        self.save_func = save_func or estimator_save_fn or save
        self.load_func = load_func or estimator_load_fn or load
        self.dataset_map = dataset_map
        self.input_dask_array = input_dask_array
        self.fit_kwargs = fit_kwargs or {}
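The two _get_tags() lookups above let an estimator override how its checkpointed features are written and read. A minimal sketch, assuming scikit-learn's _more_tags mechanism (which _get_tags() aggregates); MyTransformer and its pickle-based tag values are illustrative, not part of this diff:

import pickle

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


def _save_pickle(data, path):
    # Same signature as the new default: save_func(data, path).
    with open(path, "wb") as f:
        pickle.dump(data, f)


def _load_pickle(path):
    with open(path, "rb") as f:
        return pickle.load(f)


class MyTransformer(TransformerMixin, BaseEstimator):
    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return np.asarray(X)

    def _more_tags(self):
        # Block picks these up via estimator._get_tags(); any callables
        # with the (data, path) / (path) signatures work, pickle is used
        # here purely for illustration.
        return {
            "bob_features_save_fn": _save_pickle,
            "bob_features_load_fn": _load_pickle,
        }

When the tags are absent and no explicit save_func/load_func is given, Block falls back to bob.io.base.save and bob.io.base.load.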
@@ -204,8 +217,8 @@
    def save(self, key, data):
        path = self.make_path(key)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        # this should be save_func(path, data) so it's compatible with np.save
        return self.save_func(path, data)
        # this should be save_func(data, path) so it's compatible with bob.io.base.save
        return self.save_func(data, path)

    def load(self, key):
        path = self.make_path(key)
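For reference, a minimal round-trip sketch of the new checkpoint defaults; note the argument order save(data, path), matching bob.io.base.save rather than np.save. The temporary path and array are illustrative only.

import os
import tempfile

import numpy as np

import bob.io.base

data = np.arange(3, dtype="float64")
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "feature_0.hdf5")
    # new default: save_func(data, path) -> bob.io.base.save
    bob.io.base.save(data, path)
    # new default: load_func(path) -> bob.io.base.load
    loaded = bob.io.base.load(path)
    np.testing.assert_array_equal(loaded, data)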