Commit 56847997 authored by Amir MOHAMMADI
Browse files

Merge branch 'sample_fta' into 'master'

Handled failed processing (Failure to Acquire) in the wrappers

Closes #32

See merge request !66
parents e320416b f4bfce4d
Pipeline #51180 passed with stages
in 8 minutes and 31 seconds
......@@ -73,6 +73,25 @@ class DummyTransformer(TransformerMixin, BaseEstimator):
return {"stateless": True, "requires_fit": False}
class HalfFailingDummyTransformer(DummyTransformer):
    """Dummy transformer whose output is ``None`` for every even-indexed sample."""

    def transform(self, X):
        # Validate the input (non-finite values are tolerated), then run it
        # through ``_offset_add_func`` exactly like a successful transform.
        shifted = _offset_add_func(check_array(X, force_all_finite=False))
        # Odd positions keep their transformed row; even positions (0, 2, ...)
        # are replaced by ``None`` to simulate a processing failure.
        return [row if idx % 2 else None for idx, row in enumerate(shifted)]
class FullFailingDummyTransformer(DummyTransformer):
    """Dummy transformer that reports a failure (``None``) for every sample."""

    def transform(self, X):
        # Only the number of inputs matters: one ``None`` placeholder each.
        n_inputs = len(X)
        return [None] * n_inputs
def _assert_all_close_numpy_array(oracle, result):
oracle, result = np.array(oracle), np.array(result)
assert (
......@@ -119,6 +138,128 @@ def test_fittable_sample_transformer():
_assert_all_close_numpy_array(X + 1, [s.data for s in features])
def test_failing_sample_transformer():
    """Check that sample-wrapped pipelines carry failed samples through as ``None``.

    Two scenarios are exercised, each with two chained stages:

    * ``HalfFailingDummyTransformer`` twice: some samples fail at each stage
      and the survivors must hold the fully transformed value.
    * ``FullFailingDummyTransformer`` twice: every sample must come out with
      ``data`` set to ``None``.
    """
    X = np.zeros(shape=(10, 2))

    # --- Partial failure ---------------------------------------------------
    samples = [mario.Sample(data) for data in X]
    # The ``np.object`` alias was deprecated in NumPy 1.20 and removed in
    # 1.24; the builtin ``object`` is the supported spelling of this dtype.
    expected = np.full_like(X, 2, dtype=object)
    # Rows expected to fail in the first stage (even indices).
    expected[::2] = None
    # Rows expected to fail in the second stage (indices 1, 5, 9).
    expected[1::4] = None

    transformer = Pipeline(
        [
            ("1", mario.wrap([HalfFailingDummyTransformer, "sample"])),
            ("2", mario.wrap([HalfFailingDummyTransformer, "sample"])),
        ]
    )
    features = transformer.transform(samples)
    features = [f.data for f in features]

    assert len(expected) == len(
        features
    ), f"Expected: {len(expected)} but got: {len(features)}"
    # ``e`` is always an object-dtype row, so ``e == f`` is an element-wise
    # comparison even when ``f`` is ``None``.
    assert all(
        (e == f).all() for e, f in zip(expected, features)
    ), f"Expected: {expected} but got: {features}"

    # --- Total failure -----------------------------------------------------
    samples = [mario.Sample(data) for data in X]
    expected = [None] * X.shape[0]

    transformer = Pipeline(
        [
            ("1", mario.wrap([FullFailingDummyTransformer, "sample"])),
            ("2", mario.wrap([FullFailingDummyTransformer, "sample"])),
        ]
    )
    features = transformer.transform(samples)
    features = [f.data for f in features]

    assert len(expected) == len(
        features
    ), f"Expected: {len(expected)} but got: {len(features)}"
    assert all(
        e == f for e, f in zip(expected, features)
    ), f"Expected: {expected} but got: {features}"
def test_failing_checkpoint_transformer():
    """Check failure propagation through sample + checkpoint wrapped pipelines.

    Mirrors the sample-only failure test, but every stage is also wrapped
    with "checkpoint" and writes into its own directory under a temporary
    root. Both the partially failing and the fully failing transformer are
    exercised.
    """

    def _two_stage_pipeline(transformer_cls, root):
        # Chain two identically wrapped stages, each with its own features
        # directory below ``root``.
        steps = []
        for label, dirname in (("1", "features_1"), ("2", "features_2")):
            wrapped = mario.wrap(
                [transformer_cls, "sample", "checkpoint"],
                features_dir=os.path.join(root, dirname),
            )
            steps.append((label, wrapped))
        return Pipeline(steps)

    X = np.zeros(shape=(10, 2))

    # --- Partial failure ---------------------------------------------------
    samples = [mario.Sample(row, key=k) for k, row in enumerate(X)]
    # ``expected`` is a float array, so assigning ``None`` stores NaN in the
    # rows that are expected to fail.
    expected = np.full_like(X, 2)
    expected[::2] = None
    expected[1::4] = None
    expected = list(expected)

    with tempfile.TemporaryDirectory() as tmp_root:
        pipeline = _two_stage_pipeline(HalfFailingDummyTransformer, tmp_root)
        features = pipeline.transform(samples)

        # Substitute a NaN row for every failed sample so the result can be
        # compared numerically against ``expected``.
        rows = []
        for f in features:
            if f.data is None:
                rows.append(np.full(X.shape[1], np.nan))
            else:
                rows.append(f.data)
        np_features = np.array(rows)

        assert len(expected) == len(
            np_features
        ), f"Expected: {len(expected)} but got: {len(np_features)}"
        assert np.allclose(
            expected, np_features, equal_nan=True
        ), f"Expected: {expected} but got: {np_features}"

    # --- Total failure -----------------------------------------------------
    samples = [mario.Sample(row, key=k) for k, row in enumerate(X)]
    expected = [None] * X.shape[0]

    with tempfile.TemporaryDirectory() as tmp_root:
        pipeline = _two_stage_pipeline(FullFailingDummyTransformer, tmp_root)
        features = pipeline.transform(samples)

        assert len(expected) == len(
            features
        ), f"Expected: {len(expected)} but got: {len(features)}"
        assert all(
            e == f.data for e, f in zip(expected, features)
        ), f"Expected: {expected} but got: {features}"
def _assert_checkpoints(features, oracle, model_path, features_dir, stateless):
_assert_all_close_numpy_array(oracle, [s.data for s in features])
if stateless:
......
......@@ -82,9 +82,21 @@ class DelayedSamplesCall:
def __call__(self, index):
    """Return the output that corresponds to the sample at ``index``.

    On the first call the wrapped function is run once over the batch and
    the result is cached in ``self.output``; later calls only index into
    that cache. Samples whose ``data`` is ``None`` (a previous step failed
    for them) are not passed to the function and get ``None`` as output, so
    the cached output stays aligned with ``self.samples``.
    """
    if self.output is None:
        # Isolate invalid samples (when previous transformers returned None)
        # NOTE(review): validity is checked on ``.data`` while the batch is
        # built from ``self.sample_attribute`` -- confirm these always agree.
        invalid_ids = [i for i, s in enumerate(self.samples) if s.data is None]
        valid_samples = [s for s in self.samples if s.data is not None]

        # Process only the valid samples; the function never sees the
        # failed ones.
        if valid_samples:
            X = SampleBatch(valid_samples, sample_attribute=self.sample_attribute)
            self.output = self.func(X)
            _check_n_input_output(valid_samples, self.output, self.func_name)

        # Nothing was computed (e.g. no valid sample at all): fall back to
        # one ``None`` per valid sample so indexing below still works.
        if self.output is None:
            self.output = [None] * len(valid_samples)

        # Rebuild the full batch of samples (include the previously failed).
        # ``invalid_ids`` is ascending, so inserting in order places each
        # ``None`` at its original position.
        if invalid_ids:
            self.output = list(self.output)
            for i in invalid_ids:
                self.output.insert(i, None)
    return self.output[index]
......@@ -311,8 +323,11 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
if should_compute:
feat = computed_features[com_feat_index]
com_feat_index += 1
# save the computed feature
if p is not None:
# save the computed feature when valid (not None)
if (
p is not None
and getattr(feat, self.sample_attribute) is not None
):
self.save(feat)
feat = self.load(s, p)
features.append(feat)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment