Commit 728499bd authored by Yannick DAYER's avatar Yannick DAYER

Allowed passing of failed processing information

parent e320416b
......@@ -73,6 +73,25 @@ class DummyTransformer(TransformerMixin, BaseEstimator):
return {"stateless": True, "requires_fit": False}
class HalfFailingDummyTransformer(DummyTransformer):
"""Transformer that fails for some samples (all even indices fail)"""
def transform(self, X):
X = check_array(X, force_all_finite=False)
X = _offset_add_func(X)
output = []
for i, x in enumerate(X):
output.append(x if i % 2 else None)
return output
class FullFailingDummyTransformer(DummyTransformer):
"""Transformer that fails for all samples"""
def transform(self, X):
return [None] * len(X)
def _assert_all_close_numpy_array(oracle, result):
oracle, result = np.array(oracle), np.array(result)
assert (
......@@ -119,6 +138,130 @@ def test_fittable_sample_transformer():
_assert_all_close_numpy_array(X + 1, [s.data for s in features])
def test_failing_sample_transformer():
X = np.zeros(shape=(10, 2))
samples = [mario.Sample(data, key=i) for i, data in enumerate(X)]
expected = np.full_like(X, 2)
expected[::2] = np.nan
expected[1::4] = np.nan
transformer = Pipeline(
[
("1", mario.wrap([HalfFailingDummyTransformer, "sample"])),
("2", mario.wrap([HalfFailingDummyTransformer, "sample"])),
]
)
features = transformer.transform(samples)
np_features = np.array(
[np.full(X.shape[1], np.nan) if f.data is None else f.data for f in features]
)
assert len(expected) == len(
np_features
), f"Expected: {len(expected)} but got: {len(np_features)}"
assert np.allclose(
expected, np_features, equal_nan=True
), f"Expected: {expected} but got: {np_features}"
samples = [mario.Sample(data) for data in X]
expected = np.full(X.shape[0], np.nan)
transformer = Pipeline(
[
("1", mario.wrap([FullFailingDummyTransformer, "sample"])),
("2", mario.wrap([FullFailingDummyTransformer, "sample"])),
]
)
features = transformer.transform(samples)
np_features = np.array([np.nan if not f.data else f.data for f in features])
assert len(expected) == len(
np_features
), f"Expected: {len(expected)} but got: {len(np_features)}"
assert np.allclose(
expected, np_features, equal_nan=True
), f"Expected: {expected} but got: {np_features}"
def test_failing_checkpoint_transformer():
X = np.zeros(shape=(10, 2))
samples = [mario.Sample(data, key=i) for i, data in enumerate(X)]
expected = np.full_like(X, 2)
expected[::2] = np.nan
expected[1::4] = np.nan
with tempfile.TemporaryDirectory() as d:
features_dir_1 = os.path.join(d, "features_1")
features_dir_2 = os.path.join(d, "features_2")
transformer = Pipeline(
[
(
"1",
mario.wrap(
[HalfFailingDummyTransformer, "sample", "checkpoint"],
features_dir=features_dir_1,
),
),
(
"2",
mario.wrap(
[HalfFailingDummyTransformer, "sample", "checkpoint"],
features_dir=features_dir_2,
),
),
]
)
features = transformer.transform(samples)
np_features = np.array(
[
np.full(X.shape[1], np.nan) if f.data is None else f.data
for f in features
]
)
assert len(expected) == len(
np_features
), f"Expected: {len(expected)} but got: {len(np_features)}"
assert np.allclose(
expected, np_features, equal_nan=True
), f"Expected: {expected} but got: {np_features}"
samples = [mario.Sample(data, key=i) for i, data in enumerate(X)]
expected = np.full(X.shape[0], np.nan)
with tempfile.TemporaryDirectory() as d:
features_dir_1 = os.path.join(d, "features_1")
features_dir_2 = os.path.join(d, "features_2")
transformer = Pipeline(
[
(
"1",
mario.wrap(
[FullFailingDummyTransformer, "sample", "checkpoint"],
features_dir=features_dir_1,
),
),
(
"2",
mario.wrap(
[FullFailingDummyTransformer, "sample", "checkpoint"],
features_dir=features_dir_2,
),
),
]
)
features = transformer.transform(samples)
np_features = np.array([f.data for f in features])
assert len(expected) == len(
np_features
), f"Expected: {len(expected)} but got: {len(np_features)}"
assert np.allclose(
expected, np_features, equal_nan=True
), f"Expected: {expected} but got: {np_features}"
def _assert_checkpoints(features, oracle, model_path, features_dir, stateless):
_assert_all_close_numpy_array(oracle, [s.data for s in features])
if stateless:
......
"""Scikit-learn Estimator Wrappers."""
import logging
import os
import numpy
from functools import partial
......@@ -82,8 +83,18 @@ class DelayedSamplesCall:
def __call__(self, index):
if self.output is None:
X = SampleBatch(self.samples, sample_attribute=self.sample_attribute)
self.output = self.func(X)
# Isolate invalid samples (when previous transformers returned None)
invalid_ids = [i for i, s in enumerate(self.samples) if s.data is None]
valid_samples = [s for s in self.samples if s.data is not None]
# Process only the valid samples
if len(valid_samples) > 0:
X = SampleBatch(valid_samples, sample_attribute=self.sample_attribute)
self.output = self.func(X)
if self.output is None:
self.output = [None] * len(valid_samples)
# Rebuild the full batch of samples (including previously failed)
for i in invalid_ids:
self.output = numpy.insert(self.output, i, None, axis=0)
_check_n_input_output(self.samples, self.output, self.func_name)
return self.output[index]
......@@ -311,8 +322,12 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
if should_compute:
feat = computed_features[com_feat_index]
com_feat_index += 1
# save the computed feature
if p is not None:
# save the computed feature when valid (not NaN)
if (
p is not None
and getattr(feat, self.sample_attribute) is not None
and not numpy.isnan(getattr(feat, self.sample_attribute)).any()
):
self.save(feat)
feat = self.load(s, p)
features.append(feat)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment