Skip to content
Snippets Groups Projects
Commit 558c68f9 authored by Yannick DAYER's avatar Yannick DAYER
Browse files

Merge branch 'support-passthrough' into 'master'

Support a pipeline with "passthrough" or None as estimator

See merge request !302
parents 0988e735 a2df810f
No related branches found
No related tags found
1 merge request!302Support a pipeline with "passthrough" or None as estimator
Pipeline #63223 passed
...@@ -198,12 +198,16 @@ def check_valid_pipeline(pipeline_simple): ...@@ -198,12 +198,16 @@ def check_valid_pipeline(pipeline_simple):
""" """
# CHECKING THE TRANSFORMER # CHECKING THE TRANSFORMER
# Checking if it's a Scikit Pipeline or a estimator # Checking if it's a Scikit Pipeline or an estimator
if isinstance(pipeline_simple.transformer, Pipeline): if isinstance(pipeline_simple.transformer, Pipeline):
# Checking if all steps are wrapped as samples, if not, we should wrap them # Checking if all steps are wrapped as samples, if not, we should wrap them
for p in pipeline_simple.transformer: for p in pipeline_simple.transformer:
if not is_instance_nested(p, "estimator", SampleWrapper): if (
not is_instance_nested(p, "estimator", SampleWrapper)
and type(p) is not str
and p is not None
):
wrap(["sample"], p) wrap(["sample"], p)
# In this case it can be a simple estimator. AND # In this case it can be a simple estimator. AND
......
...@@ -518,6 +518,26 @@ def test_database_full_failure(): ...@@ -518,6 +518,26 @@ def test_database_full_failure():
_run_with_failure(False, sporadic_fail=False) _run_with_failure(False, sporadic_fail=False)
def test_pipeline_simple_passthrough():
"""Ensure that PipelineSimple accepts a passthrough Estimator."""
passthrough = make_pipeline(None)
pipeline = PipelineSimple(passthrough, Distance())
assert isinstance(pipeline, PipelineSimple)
pipeline_with_passthrough = make_pipeline("passthrough")
pipeline = PipelineSimple(pipeline_with_passthrough, Distance())
assert isinstance(pipeline, PipelineSimple)
db = DummyDatabase()
scores = pipeline(
db.background_model_samples(), db.references(), db.probes()
)
assert len(scores) == 10
for sample_scores in scores:
assert len(sample_scores) == 10
for score in sample_scores:
assert isinstance(score.data, float)
def _create_test_config(path): def _create_test_config(path):
with open(path, "w") as f: with open(path, "w") as f:
f.write( f.write(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment