Commit dbd5dace authored by Yannick DAYER's avatar Yannick DAYER

Tags definition

parent 43fd5280
......@@ -18,7 +18,6 @@ from .wrappers import SampleWrapper
from .wrappers import ToDaskBag
from .wrappers import dask_tags # noqa: F401
from .wrappers import wrap # noqa: F401
from .wrappers import get_default_tags # noqa: F401
def __appropriate__(*args):
......
......@@ -232,16 +232,24 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
Saves the transformed data in this directory
extension: str
Default extension of the transformed features
Default extension of the transformed features.
If None, will use the ``bob_checkpoint_extension`` tag in the estimator, and
default to ``.h5`` if needed.
save_func
Pointer to a customized function that saves transformed features to disk
Pointer to a customized function that saves transformed features to disk.
If None, will use the ``bob_features_save_fn`` tag in the estimator, and default
to ``bob.io.base.save`` if needed.
load_func
Pointer to a customized function that loads transformed features from disk
Pointer to a customized function that loads transformed features from disk.
If None, will use the ``bob_features_load_fn`` tag in the estimator, and default
to ``bob.io.base.load`` if needed.
sample_attribute: str
Defines the payload attribute of the sample (Defaul: `data`)
Defines the payload attribute of the sample.
If None, will use the ``bob_output`` tag in the estimator, and default to
``data`` if needed.
hash_fn
Pointer to a hash function. This hash function maps
......@@ -262,10 +270,10 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
estimator,
model_path=None,
features_dir=None,
extension=".h5",
extension=None,
save_func=None,
load_func=None,
sample_attribute="data",
sample_attribute=None,
hash_fn=None,
attempts=10,
**kwargs,
......@@ -274,18 +282,18 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
self.estimator = estimator
self.model_path = model_path
self.features_dir = features_dir
self.extension = extension
self.save_func = (
save_func
or estimator._get_tags().get("bob_features_save_fn")
or bob.io.base.save
self.extension = extension or estimator._get_tags().get(
"bob_checkpoint_extension", ".h5"
)
self.load_func = (
load_func
or estimator._get_tags().get("bob_features_load_fn")
or bob.io.base.load
self.save_func = save_func or estimator._get_tags().get(
"bob_features_save_fn", bob.io.base.save
)
self.load_func = load_func or estimator._get_tags().get(
"bob_features_load_fn", bob.io.base.load
)
self.sample_attribute = sample_attribute or estimator._get_tags().get(
"bob_checkpoint_attribute", "data"
)
self.sample_attribute = sample_attribute
self.hash_fn = hash_fn
self.attempts = attempts
if model_path is None and features_dir is None:
......@@ -568,40 +576,85 @@ class ToDaskBag(TransformerMixin, BaseEstimator):
return {"stateless": True, "requires_fit": False}
def get_default_bob_tags(estimator_tags=None):
    """Returns the default tags of a Transformer unless specified.

    Relies on the tags API of sklearn to set and retrieve the tags.

    Specify tags values in ``Transformer._more_tags``:

    .. code-block:: py

        class My_annotator_transformer(sklearn.base.BaseEstimator):
            def _more_tags(self):
                return {"bob_output": "annotations"}

    Retrieve all the tags with ``Transformer._get_tags``.

    Parameters
    ----------
    estimator_tags: dict[str, typing.Any], optional
        Tags with a non-default value (typically the result of
        ``estimator._get_tags()``). Entries in this mapping take precedence
        over the defaults. ``None`` (the default) or an empty mapping returns
        the defaults only.

    Returns
    -------
    dict[str, typing.Any]
        A new dictionary holding every default bob tag, overridden and
        extended by the entries of ``estimator_tags``.
    """
    wrap_default_tags = {
        "bob_is_checkpointable": True,  # Used to skip the wrapping of checkpoint
        "bob_is_daskable": True,  # Used to skip the dask wrapping
        "bob_sample_input": False,  # Used to skip the sample wrapping
    }
    samplewrapper_default_tags = {
        "bob_input": "data",  # Selects Which field of Sample is fed to transform
        "bob_extra_input": {},  # Specifies additional fields input to transform
        "bob_extra_fit_input": {},  # Specifies additional fields input to fit
        "bob_output": "data",  # Output field to save when checkpointing
    }
    checkpointwrapper_default_tags = {
        "bob_checkpoint_extension": ".h5",
        "bob_features_save_fn": bob.io.base.save,  # Function used to checkpoint
        "bob_features_load_fn": bob.io.base.load,  # Function used to restore from checkpoint
    }
    # The estimator's own tags are unpacked last so they win over the defaults.
    # ``None`` replaces the former mutable ``{}`` default argument.
    return {
        **wrap_default_tags,
        **samplewrapper_default_tags,
        **checkpointwrapper_default_tags,
        **(estimator_tags or {}),
    }
def needs_wrap(estimator, wrapper):
    """Tells whether ``estimator`` should be wrapped with ``wrapper``.

    The decision is driven by the bob tags of the estimator:

    - ``bob_is_checkpointable`` (default True): ``CheckpointWrapper`` is
      skipped when False.
    - ``bob_is_daskable`` (default True): ``DaskWrapper`` is skipped when
      False.
    - ``bob_sample_input`` (default False): ``SampleWrapper`` is skipped when
      True, since the estimator already takes samples as input.

    Parameters
    ----------
    estimator
        The estimator about to be wrapped, or None when the wrapper class
        itself is used to create the estimator.
    wrapper
        The wrapper class that would be applied.

    Returns
    -------
    bool
        True if ``wrapper`` must be applied, False if the estimator's tags
        ask to skip it.
    """
    if estimator is None:
        # Nothing to inspect: the wrapper class creates the estimator itself.
        return True

    bob_tags = get_default_bob_tags(estimator._get_tags())

    if wrapper is CheckpointWrapper:
        return bool(bob_tags.get("bob_is_checkpointable", True))
    if wrapper is DaskWrapper:
        return bool(bob_tags.get("bob_is_daskable", True))
    if wrapper is SampleWrapper:
        # The tag is named ``bob_sample_input``; an estimator that already
        # consumes samples must not be sample-wrapped.
        return not bob_tags.get("bob_sample_input", False)

    # No tag controls the other wrappers (e.g. ToDaskBag): always apply them.
    return True
def wrap(bases, estimator=None, **kwargs):
"""Wraps several estimators inside each other.
If ``estimator`` is a pipeline, the estimators in that pipeline are wrapped.
Use estimator tags to wrap conditionally, or to pass special variables to an
estimator. Processed tags are:
- *bob_is_checkpointable* Default: True; Skips the checkpoint wrapping if False.
- *bob_is_daskable* Default: True; Skips the dask wrapping if False.
- *bob_sample_input* Default: False; The transformer takes samples as input if True.
- *bob_input* Default: "data"; Selects Which field of Sample is fed to transform.
- *bob_extra_input* Default: {}; Specifies additional fields input to transform.
- *bob_extra_fit_input* Default: {}; Specifies additional fields input to fit.
- *bob_output* Default: "data"; Output field to save when checkpointing.
- *bob_features_save_fn* Default: bob.io.base.save; Function used to checkpoint.
- *bob_features_load_fn* Default: bob.io.base.load; Function used to restore from checkpoint.
Parameters
----------
bases : list
A list of classes to be used
A list of classes to be used to wrap ``estimator``.
estimator : :any:`object`, optional
An initial estimator to be wrapped inside other wrappers. If None, the first class will be used to initialize the estimator.
An initial estimator to be wrapped inside other wrappers.
If None, the first class will be used to initialize the estimator.
**kwargs
Extra parameters passed to the init of classes.
......@@ -631,7 +684,7 @@ def wrap(bases, estimator=None, **kwargs):
params = {k: kwargs.pop(k) for k in valid_params if k in kwargs}
if estimator is None:
estimator = w_class(**params)
else:
elif needs_wrap(estimator, w_class):
estimator = w_class(estimator, **params)
return estimator, kwargs
......@@ -671,7 +724,6 @@ def wrap(bases, estimator=None, **kwargs):
return estimator
def dask_tags(estimator):
"""Recursively collects resource_tags in dasked estimators."""
tags = {}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment