Skip to content
Snippets Groups Projects

Handle estimator tags in wrapper classes

Merged Yannick DAYER requested to merge wrap-tags into master
1 file
+ 26
0
Compare changes
  • Side-by-side
  • Inline
+ 26
0
@@ -495,7+495,7 @@
estimator,
fit_tag=None,
transform_tag=None,
**kwargs,
):
super().__init__(**kwargs)
self.estimator = estimator
@@ -610,7+610,7 @@
return {"stateless": True, "requires_fit": False}
def get_default_tags():
"""Returns the default tags of a Transformer.
Relies on the tags API of sklearn.
Specify tags in ``Transformer._more_tags``:
.. code-block:: py
def _more_tags(self):
return {"bob_input": "annotations"}
Retrieve all the tags with ``Transformer._get_tags``.
"""
return {
"bob_is_checkpointable": True,
"bob_is_daskable": True,
"bob_input": "data",
"bob_extra_input": [],
"bob_output": "data",
"bob_load_fct": bob.io.base.load,
"bob_save_fct": bob.io.base.save,
}
def wrap(bases, estimator=None, **kwargs):
"""Wraps several estimators inside each other.
@@ -690,6 +715,7 @@ def wrap(bases, estimator=None, **kwargs):
return estimator
def dask_tags(estimator):
"""Recursively collects resource_tags in dasked estimators."""
tags = {}
Loading