diff --git a/bob/bio/base/pipelines/pipelines.py b/bob/bio/base/pipelines/pipelines.py index 3b13da245914cf4e7322b65dba1cd8474f024d7a..6cb48bcd8ba8220ad3fcd89285649dd11cd54136 100644 --- a/bob/bio/base/pipelines/pipelines.py +++ b/bob/bio/base/pipelines/pipelines.py @@ -198,12 +198,16 @@ def check_valid_pipeline(pipeline_simple): """ # CHECKING THE TRANSFORMER - # Checking if it's a Scikit Pipeline or a estimator + # Checking if it's a Scikit Pipeline or an estimator if isinstance(pipeline_simple.transformer, Pipeline): # Checking if all steps are wrapped as samples, if not, we should wrap them for p in pipeline_simple.transformer: - if not is_instance_nested(p, "estimator", SampleWrapper): + if ( + not is_instance_nested(p, "estimator", SampleWrapper) + and type(p) is not str + and p is not None + ): wrap(["sample"], p) # In this case it can be a simple estimator. AND diff --git a/bob/bio/base/test/test_pipeline_simple.py b/bob/bio/base/test/test_pipeline_simple.py index 746aa3360364cc38152b9ea380233c435ab508cf..7e5307b0b76abd0c362cdf3dc0a7200eb44b325b 100644 --- a/bob/bio/base/test/test_pipeline_simple.py +++ b/bob/bio/base/test/test_pipeline_simple.py @@ -518,6 +518,26 @@ def test_database_full_failure(): _run_with_failure(False, sporadic_fail=False) +def test_pipeline_simple_passthrough(): + """Ensure that PipelineSimple accepts a passthrough Estimator.""" + passthrough = make_pipeline(None) + pipeline = PipelineSimple(passthrough, Distance()) + assert isinstance(pipeline, PipelineSimple) + + pipeline_with_passthrough = make_pipeline("passthrough") + pipeline = PipelineSimple(pipeline_with_passthrough, Distance()) + assert isinstance(pipeline, PipelineSimple) + db = DummyDatabase() + scores = pipeline( + db.background_model_samples(), db.references(), db.probes() + ) + assert len(scores) == 10 + for sample_scores in scores: + assert len(sample_scores) == 10 + for score in sample_scores: + assert isinstance(score.data, float) + + def _create_test_config(path): with open(path, "w") as f: f.write(