Skip to content
Snippets Groups Projects

Created a function checking if a Scikit learn pipeline is wrapped

Merged Tiago de Freitas Pereira requested to merge wrap into master

Files

+ 27
1
@@ -2,10 +2,36 @@ import random
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer
import bob.pipelines as mario
from bob.pipelines import Sample, SampleSet
from bob.pipelines.utils import flatten_samplesets
from bob.pipelines.utils import flatten_samplesets, is_estimator_wrapped
from bob.pipelines.wrappers import CheckpointWrapper, SampleWrapper, wrap
def test_is_estimator_wrapped():
def do_something(X):
return X
my_pipe = make_pipeline(
FunctionTransformer(do_something), FunctionTransformer(do_something)
)
assert is_estimator_wrapped(my_pipe, SampleWrapper) is False
assert is_estimator_wrapped(my_pipe, CheckpointWrapper) is False
# Sample wrap
my_pipe = wrap(["sample"], my_pipe)
assert is_estimator_wrapped(my_pipe, SampleWrapper)
assert is_estimator_wrapped(my_pipe, CheckpointWrapper) is False
# Checkpoint wrap
my_pipe = wrap(["checkpoint"], my_pipe)
assert is_estimator_wrapped(my_pipe, SampleWrapper)
assert is_estimator_wrapped(my_pipe, CheckpointWrapper)
def test_isinstance_nested():
Loading