Commit a61d9797 authored by Amir MOHAMMADI

[xarray][samples_to_dataset] Make sure it is working with new DelayedSamples

parent 98832088
Pipeline #45916 passed with stage in 3 minutes and 37 seconds
import os
import tempfile
from functools import partial
import dask
import dask_ml.decomposition
import dask_ml.preprocessing
@@ -17,9 +19,25 @@ from sklearn.preprocessing import StandardScaler
import bob.pipelines as mario


-def _build_toy_samples():
+def _build_toy_samples(delayed=False):
    X = np.ones(shape=(10, 5), dtype=int)
-    samples = [mario.Sample(data, key=str(i)) for i, data in enumerate(X)]
+    if delayed:
+
+        def _load(index, attr):
+            if attr == "data":
+                return X[index]
+            if attr == "key":
+                return str(index)
+
+        samples = [
+            mario.DelayedSample(
+                partial(_load, i, "data"),
+                delayed_attributes=dict(key=partial(_load, i, "key")),
+            )
+            for i in range(len(X))
+        ]
+    else:
+        samples = [mario.Sample(data, key=str(i)) for i, data in enumerate(X)]
    return X, samples
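The helper above now covers both sample flavours. As a rough standalone sketch of the difference, assuming (as the delayed branch relies on) that a `DelayedSample` resolves `data` and each `delayed_attributes` entry by calling the corresponding function on attribute access; `load_data` below is purely illustrative:

```python
from functools import partial

import numpy as np

import bob.pipelines as mario

X = np.ones(shape=(10, 5), dtype=int)


def load_data(index):
    # hypothetical loader, standing in for e.g. reading one sample from disk
    return X[index]


# eager: the array and the key are stored on the sample immediately
eager = mario.Sample(X[0], key="0")

# delayed: data and key are callables, resolved when the attribute is read
lazy = mario.DelayedSample(
    partial(load_data, 0),
    delayed_attributes=dict(key=partial(str, 0)),
)

assert np.array_equal(lazy.data, eager.data)
assert lazy.key == eager.key == "0"
```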
@@ -31,17 +49,46 @@ def test_samples_to_dataset():
    np.testing.assert_array_equal(dataset["key"], [str(i) for i in range(10)])


-def _build_iris_dataset(shuffle=False):
+def test_delayed_samples_to_dataset():
+    X, samples = _build_toy_samples(delayed=True)
+    dataset = mario.xr.samples_to_dataset(samples)
+    assert dataset.dims == {"sample": X.shape[0], "dim_0": X.shape[1]}, dataset.dims
+    np.testing.assert_array_equal(dataset["data"], X)
+    np.testing.assert_array_equal(dataset["key"], [str(i) for i in range(10)])
+
+
+def _build_iris_dataset(shuffle=False, delayed=False):
    iris = datasets.load_iris()
    X = iris.data
    keys = [str(k) for k in range(len(X))]
-    samples = [
-        mario.Sample(x, target=y, key=k)
-        for x, y, k in zip(iris.data, iris.target, keys)
-    ]
+    if delayed:
+
+        def _load(index, attr):
+            if attr == "data":
+                return X[index]
+            if attr == "key":
+                return str(index)
+            if attr == "target":
+                return iris.target[index]
+
+        samples = [
+            mario.DelayedSample(
+                partial(_load, i, "data"),
+                delayed_attributes=dict(
+                    key=partial(_load, i, "key"),
+                    target=partial(_load, i, "target"),
+                ),
+            )
+            for i in range(len(X))
+        ]
+    else:
+        samples = [
+            mario.Sample(x, target=y, key=k)
+            for x, y, k in zip(iris.data, iris.target, keys)
+        ]
    meta = xr.DataArray(X[0], dims=("feature",))
    dataset = mario.xr.samples_to_dataset(
        samples, meta=meta, npartitions=3, shuffle=shuffle
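The `meta` template passed here is what names the feature axis of the resulting dataset: without it the second dimension gets xarray's default `dim_0` name (as asserted in the toy tests above), while a one-sample `DataArray` with `dims=("feature",)` should carry that name through. A small sketch of the two cases, using only calls already shown in these tests:

```python
import numpy as np
import xarray as xr

import bob.pipelines as mario

samples = [mario.Sample(x, key=str(i)) for i, x in enumerate(np.ones((6, 4)))]

# without meta, the feature axis keeps xarray's default name
ds_default = mario.xr.samples_to_dataset(samples)
assert "dim_0" in ds_default.dims

# with a template DataArray, the axis is labelled explicitly
meta = xr.DataArray(samples[0].data, dims=("feature",))
ds_named = mario.xr.samples_to_dataset(samples, meta=meta)
assert "feature" in ds_named.dims
```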
@@ -50,20 +97,21 @@ def _build_iris_dataset(shuffle=False):
def test_dataset_pipeline():
-    ds = _build_iris_dataset()
-    estimator = mario.xr.DatasetPipeline(
-        [
-            PCA(n_components=0.99),
-            {
-                "estimator": LinearDiscriminantAnalysis(),
-                "fit_input": ["data", "target"],
-            },
-        ]
-    )
+    for delayed in (True, False):
+        ds = _build_iris_dataset(delayed=delayed)
+        estimator = mario.xr.DatasetPipeline(
+            [
+                PCA(n_components=0.99),
+                {
+                    "estimator": LinearDiscriminantAnalysis(),
+                    "fit_input": ["data", "target"],
+                },
+            ]
+        )

-    estimator = estimator.fit(ds)
-    ds = estimator.decision_function(ds)
-    ds.compute()
+        estimator = estimator.fit(ds)
+        ds = estimator.decision_function(ds)
+        ds.compute()
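For reference, a condensed sketch of the flow the loop above exercises for each sample flavour, reusing the `_build_iris_dataset` helper from this test file; the comments describe what each piece is for, without any claim about how `DatasetPipeline` schedules the work internally:

```python
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

import bob.pipelines as mario

ds = _build_iris_dataset(delayed=True)  # every variable is dask-backed

pipeline = mario.xr.DatasetPipeline(
    [
        # plain estimator: an unsupervised step
        PCA(n_components=0.99),
        # dict form: "fit_input" selects the dataset variables passed to fit()
        {
            "estimator": LinearDiscriminantAnalysis(),
            "fit_input": ["data", "target"],
        },
    ]
)

scores = pipeline.fit(ds).decision_function(ds)
scores.compute()  # materialise the result
```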
def test_dataset_pipeline_with_shapes():
@@ -21,20 +21,41 @@ from .utils import is_estimator_stateless
logger = logging.getLogger(__name__)


-def _one_sample_to_dataset(sample, meta=None):
-    dataset = {k: v for k, v in sample.__dict__.items() if k not in SAMPLE_DATA_ATTRS}
+def _load_fn_to_xarray(load_fn, meta=None):
    if meta is None:
-        meta = sample.data
-
-    dataset["data"] = dask.array.from_delayed(
-        dask.delayed(sample).data, meta.shape, dtype=meta.dtype, name=False
+        meta = np.array(load_fn())
+
+    da = dask.array.from_delayed(
+        dask.delayed(load_fn)(), meta.shape, dtype=meta.dtype, name=False
    )
    try:
        dims = meta.dims
    except Exception:
        dims = None
-
-    dataset["data"] = xr.DataArray(dataset["data"], dims=dims)
-    return xr.Dataset(dataset).chunk()
+    xa = xr.DataArray(da, dims=dims)
+    return xa, meta
+
+
+def _one_sample_to_dataset(sample, meta=None):
+    dataset = {}
+    delayed_attributes = getattr(sample, "delayed_attributes", None) or {}
+    for k in sample.__dict__:
+        if k in SAMPLE_DATA_ATTRS or k in delayed_attributes:
+            continue
+        dataset[k] = getattr(sample, k)
+
+    meta = meta or {}
+    for k in ["data"] + list(delayed_attributes.keys()):
+        attr_meta = meta.get(k)
+        attr_array, attr_meta = _load_fn_to_xarray(
+            partial(getattr, sample, k), meta=attr_meta
+        )
+        meta[k] = attr_meta
+        dataset[k] = attr_array
+
+    return xr.Dataset(dataset).chunk(), meta
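To make the new `_load_fn_to_xarray` helper concrete: each attribute is wrapped in `dask.delayed`, turned into a dask array whose shape and dtype come from the `meta` template, and then wrapped in a `DataArray`. A self-contained sketch with a hypothetical `load` function (in the real code the callable is `partial(getattr, sample, k)`):

```python
import dask
import dask.array
import numpy as np
import xarray as xr


def load():
    # hypothetical loader standing in for partial(getattr, sample, "data")
    return np.arange(5, dtype=float)


meta = np.array(load())  # template, used only for its shape and dtype
da = dask.array.from_delayed(
    dask.delayed(load)(), meta.shape, dtype=meta.dtype, name=False
)
xa = xr.DataArray(da)       # lazy DataArray; load() runs again on compute
print(xa.compute().values)  # -> [0. 1. 2. 3. 4.]
```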
def samples_to_dataset(samples, meta=None, npartitions=48, shuffle=False):
@@ -58,13 +79,20 @@ samples_to_dataset(samples, meta=None, npartitions=48, shuffle=False):
    ``xarray.Dataset``
        The constructed dataset with at least a ``data`` variable.
    """
-    if meta is None:
-        dataset = _one_sample_to_dataset(samples[0])
-        meta = dataset["data"]
+    if meta is not None and not isinstance(meta, dict):
+        meta = dict(data=meta)
+
+    delayed_attributes = getattr(samples[0], "delayed_attributes", None) or {}
+    if meta is None or not all(
+        k in meta for k in ["data"] + list(delayed_attributes.keys())
+    ):
+        dataset, meta = _one_sample_to_dataset(samples[0])

    if shuffle:
        random.shuffle(samples)

    dataset = xr.concat(
-        [_one_sample_to_dataset(s, meta=meta) for s in samples], dim="sample"
+        [_one_sample_to_dataset(s, meta=meta)[0] for s in samples], dim="sample"
    )

    if npartitions is not None:
        dataset = dataset.chunk({"sample": max(1, len(samples) // npartitions)})
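A quick note on that last step: the dataset is rechunked along the `sample` dimension so that roughly `npartitions` chunks come out, with a floor of one sample per chunk. The arithmetic for the numbers used in the tests above, spelled out with a tiny helper that is not part of the library:

```python
# chunk length along "sample", mirroring the expression in samples_to_dataset
def sample_chunk_size(n_samples, npartitions):
    return max(1, n_samples // npartitions)


assert sample_chunk_size(150, 3) == 50  # iris tests: 3 chunks of 50 samples
assert sample_chunk_size(10, 48) == 1   # toy tests, default npartitions: 1 sample per chunk
```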
@@ -431,7 +459,9 @@ class DatasetPipeline(_BaseComposition):
            try:
                ds = block.dataset_map(ds)
            except Exception as e:
-                raise RuntimeError(f"Could not map ds {ds}\n with {block.dataset_map}") from e
+                raise RuntimeError(
+                    f"Could not map ds {ds}\n with {block.dataset_map}"
+                ) from e
            continue

        if do_fit: