Commit 43fd5280 authored by Yannick DAYER's avatar Yannick DAYER

Implement `get_default_tags` in wrappers

parent ca5f2851
......@@ -18,6 +18,7 @@ from .wrappers import SampleWrapper
from .wrappers import ToDaskBag
from .wrappers import dask_tags # noqa: F401
from .wrappers import wrap # noqa: F401
from .wrappers import get_default_tags # noqa: F401
def __appropriate__(*args):
......
......@@ -568,6 +568,31 @@ class ToDaskBag(TransformerMixin, BaseEstimator):
return {"stateless": True, "requires_fit": False}
def get_default_tags():
"""Returns the default tags of a Transformer.
Relies on the tags API of sklearn.
Specify tags in ``Transformer._more_tags``:
.. code-block:: py
def _more_tags(self):
return {"bob_input": "annotations"}
Retrieve all the tags with ``Transformer._get_tags``.
"""
return {
"bob_is_checkpointable": True,
"bob_is_daskable": True,
"bob_input": "data",
"bob_extra_input": [],
"bob_output": "data",
"bob_load_fct": bob.io.base.load,
"bob_save_fct": bob.io.base.save,
}
def wrap(bases, estimator=None, **kwargs):
"""Wraps several estimators inside each other.
......@@ -646,6 +671,7 @@ def wrap(bases, estimator=None, **kwargs):
return estimator
def dask_tags(estimator):
"""Recursively collects resource_tags in dasked estimators."""
tags = {}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment