Skip to content
Snippets Groups Projects

replace is_estimator_stateless with estimator_requires_fit

Merged Amir MOHAMMADI requested to merge stateless-is-requires-fit into master
12 files
+ 34
81
Compare changes
  • Side-by-side
  • Inline

Files

@@ -73,7 +73,7 @@ class DummyTransformer(TransformerMixin, BaseEstimator):
return _offset_add_func(X)
def _more_tags(self):
return {"stateless": True, "requires_fit": False}
return {"requires_fit": False}
class HalfFailingDummyTransformer(DummyTransformer):
@@ -110,8 +110,6 @@ class DummyWithTags(DummyTransformer):
def _more_tags(self):
return {
"stateless": False,
"requires_fit": True,
"bob_output": "annotations",
"bob_transform_extra_input": (
("extra_arg_1", "data"),
@@ -136,8 +134,6 @@ class DummyWithTagsNotData(DummyTransformer):
def _more_tags(self):
return {
"stateless": False,
"requires_fit": True,
"bob_output": "annotations_2",
"bob_input": "annotations",
"bob_transform_extra_input": (
@@ -165,8 +161,6 @@ class DummyWithDask(DummyTransformer):
def _more_tags(self):
return {
"stateless": False,
"requires_fit": True,
"bob_output": "annotations",
"bob_fit_supports_dask_array": True,
}
@@ -454,9 +448,11 @@ def test_failing_checkpoint_transformer():
), f"Expected: {expected} but got: {features}"
def _assert_checkpoints(features, oracle, model_path, features_dir, stateless):
def _assert_checkpoints(
features, oracle, model_path, features_dir, not_requires_fit
):
_assert_all_close_numpy_array(oracle, [s.data for s in features])
if stateless:
if not_requires_fit:
assert not os.path.exists(model_path)
else:
assert os.path.exists(model_path), os.listdir(
@@ -546,7 +542,7 @@ def test_checkpoint_fittable_sample_transformer():
model_path=model_path,
features_dir=features_dir,
)
assert not mario.utils.is_estimator_stateless(transformer)
assert mario.utils.estimator_requires_fit(transformer)
features = transformer.fit(samples).transform(samples)
_assert_checkpoints(features, oracle, model_path, features_dir, False)
Loading