Skip to content
Snippets Groups Projects

Improvements on CheckpointWrapper

All threads resolved!

Files

@@ -11,8 +11,9 @@ from sklearn.preprocessing import FunctionTransformer
from sklearn.utils.estimator_checks import check_estimator
from sklearn.utils.validation import check_array
from sklearn.utils.validation import check_is_fitted
from bob.pipelines.utils import hash_string
import bob.pipelines as mario
import tempfile
def _offset_add_func(X, offset=1):
@@ -175,6 +176,22 @@ def test_checkpoint_function_sample_transfomer():
features = transformer.transform(samples)
_assert_all_close_numpy_array(oracle, [s.data for s in features])
# test when both model_path and features_dir is None
with tempfile.TemporaryDirectory() as dir_name:
transformer = mario.wrap(
[FunctionTransformer, "sample", "checkpoint"],
func=_offset_add_func,
kw_args=dict(offset=offset),
validate=True,
features_dir=dir_name,
hash_fn=hash_string,
)
features = transformer.transform(samples)
# Checking if we have 8 chars in the second level
assert len(features[0].load.args[0].split("/")[-2]) == 8
_assert_all_close_numpy_array(oracle, [s.data for s in features])
def test_checkpoint_fittable_sample_transformer():
X = np.ones(shape=(10, 2), dtype=int)
Loading