Commit d09bc236 authored by Yannick DAYER's avatar Yannick DAYER

Finalize tags, add SampleWrapper tags, add test

parent b2a93711
Pipeline #52710 failed with stage
in 10 minutes and 51 seconds
......@@ -92,6 +92,13 @@ class FullFailingDummyTransformer(DummyTransformer):
return [None] * len(X)
class DummyWithTags(DummyTransformer):
    """Dummy transformer advertising a ``bob_output`` tag.

    Declares, through sklearn's tags API, that the result of ``transform``
    should be stored in the ``annotations`` attribute of each Sample instead
    of the default ``data`` attribute.
    """

    def _more_tags(self):
        return {"bob_output": "annotations"}
def _assert_all_close_numpy_array(oracle, result):
oracle, result = np.array(oracle), np.array(result)
assert (
......@@ -138,6 +145,21 @@ def test_fittable_sample_transformer():
_assert_all_close_numpy_array(X + 1, [s.data for s in features])
def test_tagged_sample_transformer():
    """Checks that a ``bob_output`` tag redirects the transformer output.

    ``DummyWithTags`` declares ``bob_output: "annotations"``, so the wrapped
    transformer must write its results into ``sample.annotations`` and leave
    ``sample.data`` untouched.
    """
    data = np.ones(shape=(10, 2), dtype=int)
    samples = [mario.Sample(row) for row in data]
    wrapped = mario.wrap([DummyWithTags, "sample"])
    transformed = wrapped.transform(samples)
    # The dummy transformer adds 1; per the tag, results land in `annotations`.
    _assert_all_close_numpy_array(data + 1, [s.annotations for s in transformed])
    # The original payload must be preserved in `data`.
    _assert_all_close_numpy_array(data, [s.data for s in transformed])
    # TODO add more tags tests
def test_failing_sample_transformer():
X = np.zeros(shape=(10, 2))
......
......@@ -53,6 +53,73 @@ def copy_learned_attributes(from_estimator, to_estimator):
setattr(to_estimator, k, v)
def get_bob_tags(estimator=None, force_tags=None):
    """Returns the default tags of a Transformer unless forced or specified.

    Relies on the tags API of sklearn to set and retrieve the tags.

    Specify an estimator tags values with ``estimator._more_tags``:

    ```
    class My_annotator_transformer(sklearn.base.BaseEstimator):
        def _more_tags(self):
            return {"bob_output": "annotations"}
    ```

    The returned tags will take their value with the following priority:

    1. key:value in ``force_tags``, if it is present;
    2. key:value in ``estimator`` tags (set with ``estimator._more_tags()``) if it exists;
    3. the default value for that tag if none of the previous exist.

    Tags format
    -----------
    bob_transform_input: tuple
        - The first element is a string representing the Sample's attribute that will
          be used as input.
        - The following (optional) elements are tuples in the form:
          (transform kwarg name, Sample attribute name)
        Example: ("data", ("extra_arg", "annotations"),)
    bob_fit_extra_input: tuple
        Each element is a tuple of the form:
        (transform kwarg name, Sample attribute name)
        Example: (("y","annotations"), ("extra_arg", "metadata"))
    bob_output: str
        The Sample attribute in which the output of the transform is stored.
    bob_checkpoint_extension: str
        The extension of each checkpoint file.
    bob_features_save_fn: func
        The function used to save each checkpoint file.
    bob_features_load_fn: func
        The function used to load each checkpoint file.

    Parameters
    ----------
    estimator: sklearn.BaseEstimator or None
        An estimator class with tags that will overwrite the default values. When
        ``None``, only the defaults (and ``force_tags``) are used.
    force_tags: dict[str, typing.Any] or None
        Tags with a non-default value that will overwrite the default and the
        estimator tags. ``None`` is treated as an empty dict.

    Returns
    -------
    dict[str, Any]
        The resulting tags with a value (either specified, forced, or default)
    """
    # Avoid the mutable-default-argument pitfall: normalize None to an empty dict.
    force_tags = force_tags or {}
    default_tags = {
        # Selects which field(s) of a Sample are fed to transform.
        "bob_transform_input": ("data",),
        # Selects which extra fields of a Sample are fed to fit.
        "bob_fit_extra_input": tuple(),
        # Sample's destination field for the transformer output.
        "bob_output": "data",
        "bob_checkpoint_extension": ".h5",
        "bob_features_save_fn": bob.io.base.save,  # Function used to checkpoint
        "bob_features_load_fn": bob.io.base.load,  # Function used to restore
    }
    # Fix: the original called estimator._get_tags() unconditionally, which raised
    # AttributeError for the documented default ``estimator=None``.
    estimator_tags = estimator._get_tags() if estimator is not None else {}
    # Later dicts win: defaults < estimator tags < forced tags.
    return {**default_tags, **estimator_tags, **force_tags}
class BaseWrapper(MetaEstimatorMixin, BaseEstimator):
"""The base class for all wrappers."""
......@@ -133,16 +200,24 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
estimator,
transform_extra_arguments=None,
fit_extra_arguments=None,
output_attribute="data",
input_attribute="data",
output_attribute=None,
input_attribute=None,
**kwargs,
):
super().__init__(**kwargs)
self.estimator = estimator
self.transform_extra_arguments = transform_extra_arguments or tuple()
self.fit_extra_arguments = fit_extra_arguments or tuple()
self.output_attribute = output_attribute
self.input_attribute = input_attribute
# Tagged parameters
tags = get_bob_tags(estimator=estimator)
self.input_attribute = input_attribute or tags["bob_transform_input"][0]
self.transform_extra_arguments = transform_extra_arguments or (
tags["bob_transform_input"][1:]
if len(tags["bob_transform_input"]) > 1
else tuple()
)
self.fit_extra_arguments = fit_extra_arguments or tags["bob_fit_extra_input"]
self.output_attribute = output_attribute or tags["bob_output"]
def _samples_transform(self, samples, method_name):
# Transform either samples or samplesets
......@@ -233,23 +308,23 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
extension: str
Default extension of the transformed features.
If None, will use the ``bob_checkpoint_extension`` tag in the estimator, and
default to ``.h5`` if needed.
If None, will use the ``bob_checkpoint_extension`` tag in the estimator, or
default to ``.h5``.
save_func
Pointer to a customized function that saves transformed features to disk.
If None, will use the ``bob_feature_save_fn`` tag in the estimator, and default
to ``bob.io.base.save`` if needed.
If None, will use the ``bob_feature_save_fn`` tag in the estimator, or default
to ``bob.io.base.save``.
load_func
Pointer to a customized function that loads transformed features from disk.
If None, will use the ``bob_feature_load_fn`` tag in the estimator, and default
to ``bob.io.base.load`` if needed.
If None, will use the ``bob_feature_load_fn`` tag in the estimator, or default
to ``bob.io.base.load``.
sample_attribute: str
Defines the payload attribute of the sample.
If None, will use the ``bob_output`` tag in the estimator, and default to
``data`` if needed.
If None, will use the ``bob_output`` tag in the estimator, or default to
``data``.
hash_fn
Pointer to a hash function. This hash function maps
......@@ -282,20 +357,27 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
self.estimator = estimator
self.model_path = model_path
self.features_dir = features_dir
self.extension = extension or estimator._get_tags().get(
"bob_checkpoint_extension", ".h5"
)
self.save_func = save_func or estimator._get_tags().get(
"bob_features_save_fn", bob.io.base.save
)
self.load_func = load_func or estimator._get_tags().get(
"bob_features_load_fn", bob.io.base.load
)
self.sample_attribute = sample_attribute or estimator._get_tags().get(
"bob_output", "data"
)
self.hash_fn = hash_fn
self.attempts = attempts
# Tagged parameters
forced_tags = dict()
if extension is not None:
forced_tags["bob_checkpoint_extension"] = extension
if save_func is not None:
forced_tags["bob_features_save_fn"] = save_func
if load_func is not None:
forced_tags["bob_features_load_fn"] = load_func
if sample_attribute is not None:
forced_tags["bob_output"] = sample_attribute
estimator_tags = get_bob_tags(estimator=estimator, force_tags=forced_tags)
self.extension = estimator_tags["bob_checkpoint_extension"]
self.save_func = estimator_tags["bob_features_save_fn"]
self.load_func = estimator_tags["bob_features_load_fn"]
self.sample_attribute = estimator_tags["bob_output"]
# Paths check
if model_path is None and features_dir is None:
logger.warning(
"Both model_path and features_dir are None. "
......@@ -576,58 +658,11 @@ class ToDaskBag(TransformerMixin, BaseEstimator):
return {"stateless": True, "requires_fit": False}
def get_default_bob_tags(estimator_tags=None):
    """Returns the default tags of a Transformer unless specified.

    Relies on the tags API of sklearn to set and retrieve the tags.

    Specify tags values in ``Transformer._more_tags``:

    ```
    class My_annotator_transformer(sklearn.base.BaseEstimator):
        def _more_tags(self):
            return {"bob_output": "annotations"}
    ```

    Parameters
    ----------
    estimator_tags: dict[str, typing.Any] or None
        Tags with a non-default value that will overwrite the corresponding
        default tags in the returned dict. ``None`` is treated as an empty dict.

    Returns
    -------
    dict[str, Any]
        The tags specified in ``estimator_tags``, and all the default bob tags.
    """
    # Avoid the mutable-default-argument pitfall: normalize None to an empty dict.
    estimator_tags = estimator_tags or {}
    samplewrapper_default_tags = {
        "bob_input": ("data",),  # Selects which fields of a Sample are fed to transform
        "bob_output": "data",  # Output field of a transformer
    }
    checkpointwrapper_default_tags = {
        "bob_checkpoint_extension": ".h5",
        "bob_features_save_fn": bob.io.base.save,  # Function used to checkpoint
        "bob_features_load_fn": bob.io.base.load,  # Function used to restore
    }
    # Estimator-specified tags take precedence over all defaults.
    return {
        **samplewrapper_default_tags,
        **checkpointwrapper_default_tags,
        **estimator_tags,
    }
def wrap(bases, estimator=None, **kwargs):
"""Wraps several estimators inside each other.
If ``estimator`` is a pipeline, the estimators in that pipeline are wrapped.
Use estimator tags to wrap conditionally, or to pass special variables to an
estimator. Processed tags are:
- ``bob_transform_input`` Default: ("data",); Selects which field of Sample is fed to transform.
- ``bob_output`` Default: "data"; Output field (saved when checkpointing).
- ``bob_features_save_fn`` Default: bob.io.base.save; Function used to checkpoint.
- ``bob_features_load_fn`` Default: bob.io.base.load; Function used to restore from checkpoint.
Parameters
----------
bases : list
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment