Commit de866405 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira

Merge branch 'update2' into 'master'

Two updates

See merge request !19
parents 25617068 dce35ed0
Pipeline #38700 failed with stages
in 7 minutes and 32 seconds
......@@ -11,6 +11,9 @@ from sklearn.base import TransformerMixin
from sklearn.pipeline import Pipeline
from dask import delayed
import dask.bag
import logging
logger = logging.getLogger(__name__)
def estimator_dask_it(
......@@ -226,6 +229,8 @@ class SampleMixin:
def transform(self, samples):
# Transform either samples or samplesets
logger.info(f"Transforming Sample/SampleSet: {self}")
if isinstance(samples[0], Sample) or isinstance(samples[0], DelayedSample):
kwargs = _make_kwargs_from_samples(samples, self.transform_extra_arguments)
features = super().transform([s.data for s in samples], **kwargs)
......@@ -240,6 +245,8 @@ class SampleMixin:
def fit(self, samples, y=None):
logger.info(f"Fitting {self}")
# See: https://scikit-learn.org/stable/developers/develop.html
# if the estimator does not require fit or is stateless don't call fit
tags = self._get_tags()
......@@ -286,13 +293,14 @@ class CheckpointMixin:
return new_sample
def transform_one_sample_set(self, sample_set):
    """Transform every sample in ``sample_set`` and wrap the results.

    Iterates the ``SampleSet`` directly (it implements the sequence
    protocol) instead of reaching into its ``samples`` attribute, and
    returns a new ``SampleSet`` that inherits the original's attributes
    via ``parent``.
    """
    samples = [self.transform_one_sample(s) for s in sample_set]
    return SampleSet(samples, parent=sample_set)
def transform(self, samples):
if not isinstance(samples, list):
raise ValueError("It's expected a list, not %s" % type(samples))
logger.info(f"Checkpointing Sample/SampleSet: {self}")
if isinstance(samples[0], Sample) or isinstance(samples[0], DelayedSample):
return [self.transform_one_sample(s) for s in samples]
elif isinstance(samples[0], SampleSet):
......@@ -324,7 +332,7 @@ class CheckpointMixin:
os.makedirs(os.path.dirname(path), exist_ok=True)
return self.save_func(sample.data, path)
elif isinstance(sample, SampleSet):
for s in sample.samples:
for s in sample:
path = self.make_path(s)
os.makedirs(os.path.dirname(path), exist_ok=True)
return self.save_func(s.data, path)
......
from collections.abc import MutableSequence
"""Base definition of sample"""
......@@ -72,11 +74,37 @@ class Sample:
_copy_attributes(self, kwargs)
class SampleSet:
"""A set of samples with extra attributes"""
class SampleSet(MutableSequence):
    """A set of samples with extra attributes.

    Implements the :class:`collections.abc.MutableSequence` protocol so a
    ``SampleSet`` can be indexed, iterated, sliced, and mutated like a
    list, while only accepting :class:`Sample` instances on assignment
    and insertion.
    https://docs.python.org/3/library/collections.abc.html#collections-abstract-base-classes
    """

    def __init__(self, samples, parent=None, **kwargs):
        self.samples = samples
        # Copy attributes from the parent set first, then let explicit
        # keyword arguments override them.
        if parent is not None:
            _copy_attributes(self, parent.__dict__)
        _copy_attributes(self, kwargs)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, item):
        # Delegates to the underlying list, so slices return lists.
        return self.samples.__getitem__(item)

    def __setitem__(self, key, item):
        # NOTE(review): only Sample is accepted here; DelayedSample is
        # checked separately elsewhere in this file, which suggests it is
        # not a Sample subclass and would be rejected — confirm intended.
        if not isinstance(item, Sample):
            # Report the offending type, which is what the message promises.
            raise ValueError(f"item should be of type Sample, not {type(item)}")
        return self.samples.__setitem__(key, item)

    def __delitem__(self, item):
        return self.samples.__delitem__(item)

    def insert(self, index, item):
        """Insert ``item`` at ``index``; only ``Sample`` instances are accepted."""
        if not isinstance(item, Sample):
            raise ValueError(f"item should be of type Sample, not {type(item)}")
        self.samples.insert(index, item)
......@@ -352,3 +352,39 @@ def test_dask_checkpoint_transform_pipeline():
X_tr = estimator.transform(bag_transformer.transform(samples_transform))
assert len(estimator.dask_tags()) == 1
assert len(X_tr.compute(scheduler="single-threaded")) == 10
def test_checkpoint_transform_pipeline_with_sampleset():
    """Checkpointed pipelines must transform SampleSets, with and without dask."""

    def _check(dask_enabled):
        raw = np.ones(shape=(10, 2), dtype=int)
        sample_set = SampleSet(
            [Sample(row, key=str(i)) for i, row in enumerate(raw)], key="1"
        )
        offset = 2
        expected = raw + offset

        with tempfile.TemporaryDirectory() as d:
            pipeline = Pipeline(
                [(f"{i}", _build_transformer(d, i)) for i in range(offset)]
            )
            if dask_enabled:
                pipeline = estimator_dask_it(pipeline)
                result = pipeline.transform([sample_set]).compute(
                    scheduler="single-threaded"
                )
            else:
                result = pipeline.transform([sample_set])

            # Flatten the transformed SampleSets and compare against the oracle.
            flattened = [s.data for sset in result for s in sset.samples]
            _assert_all_close_numpy_array(expected, flattened)
            # Every returned SampleSet keeps its original cardinality.
            assert all(len(sset) == 10 for sset in result)

    _check(dask_enabled=True)
    _check(dask_enabled=False)
from bob.pipelines.sample import Sample, SampleSet, DelayedSample
import numpy
from nose.tools import assert_raises
import copy
def test_sampleset_collection():
    """SampleSet should behave like a mutable sequence of Sample objects."""
    n_samples = 10
    data = numpy.ones(shape=(n_samples, 2), dtype=int)
    sampleset = SampleSet(
        [Sample(row, key=str(i)) for i, row in enumerate(data)], key="1"
    )
    assert len(sampleset) == n_samples

    # Insertion grows the collection by one.
    extra = Sample(data, key=100)
    sampleset.insert(1, extra)
    assert len(sampleset) == n_samples + 1

    # Deletion shrinks it back.
    del sampleset[0]
    assert len(sampleset) == n_samples

    # Non-Sample items are rejected on insert...
    with assert_raises(ValueError):
        sampleset.insert(1, 10)

    # ...item assignment accepts Samples...
    sampleset[0] = copy.deepcopy(extra)

    # ...and rejects anything else.
    with assert_raises(ValueError):
        sampleset[0] = "xuxa"

    # Iteration yields Sample instances only.
    assert all(isinstance(item, Sample) for item in sampleset)
......@@ -27,14 +27,13 @@ requirements:
- setuptools {{ setuptools }}
- bob.extension
- bob.io.base
run:
- python
- setuptools
- dask
- dask-jobqueue
- numpy {{ numpy }}
- h5py
- scikit-learn >=0.22
run:
- python
- setuptools
test:
imports:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment