Skip to content
Snippets Groups Projects
Commit 84983bac authored by Yannick DAYER's avatar Yannick DAYER
Browse files

[py] Adds annotations-related wrappers

parent 57bd3218
No related branches found
No related tags found
1 merge request!42Adding annotations-related wrappers
Pipeline #44957 failed with stage
in 1 hour, 34 minutes, and 25 seconds
......@@ -10,9 +10,11 @@ from .sample import hdf5_to_sample # noqa
from .sample import sample_to_hdf5 # noqa
from .wrappers import BaseWrapper
from .wrappers import CheckpointWrapper
from .wrappers import CheckpointAnnotationsWrapper
from .wrappers import DaskWrapper
from .wrappers import DelayedSamplesCall
from .wrappers import SampleWrapper
from .wrappers import AnnotatedSampleWrapper
from .wrappers import ToDaskBag
from .wrappers import dask_tags # noqa
from .wrappers import wrap # noqa
......
"""Scikit-learn Estimator Wrappers."""
import logging
import json
import os
from functools import partial
......@@ -44,6 +45,15 @@ def copy_learned_attributes(from_estimator, to_estimator):
setattr(to_estimator, k, v)
def json_dump(data, path):
with open(path, "w") as f:
json.dump(data, f)
def json_load(path):
with open(path, "r") as f:
return json.load(f)
class BaseWrapper(MetaEstimatorMixin, BaseEstimator):
"""The base class for all wrappers."""
......@@ -174,6 +184,74 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
return self
class AnnotatedSampleWrapper(SampleWrapper):
"""Wraps an annotator Transformer to set the sample.annotations correctly.
An :py:class:`~bob.bio.base.Annotator` transformer simply returns its
results, (or in the :py:attr:`~bob.pipelines.Sample.data` attribute when
wrapped with :py:class:`~bob.pipelines.SampleWrapper`).
Use this wrapper uniquely with Annotators, and INSTEAD of the
:py:class:`~bob.pipelines.SampleWrapper`.
Attributes
----------
fit_extra_arguments : [tuple]
Use this option if you want to pass extra arguments to the fit method of the
mixed instance. The format is a list of two value tuples. The first value in
tuples is the name of the argument that fit accepts, like ``y``, and the second
value is the name of the attribute that samples carry. For example, if you are
passing samples to the fit method and want to pass ``subject`` attributes of
samples as the ``y`` argument to the fit method, you can provide ``[("y",
"subject")]`` as the value for this attribute.
transform_extra_arguments : [tuple]
Similar to ``fit_extra_arguments`` but for the transform and other similar methods.
"""
def __init__(
self,
annotator,
transform_extra_arguments=None,
fit_extra_arguments=None,
**kwargs,
):
super().__init__(
estimator=annotator,
transform_extra_arguments=transform_extra_arguments,
fit_extra_arguments=fit_extra_arguments,
**kwargs,
)
def _samples_transform(self, samples, method_name):
"""
Transforms a set of samples by calling the annotator with any method.
Overrides SampleWrapper.sample_transform to insert annotations in their
field (:py:attr:`~bob.pipelines.Sample.annotations`) instead of the
:py:attr:`~bob.pipelines.Sample.data` field.
"""
# Transform either samples or samplesets
method = getattr(self.estimator, method_name)
logger.debug(f"{_frmt(self)}.{method_name}")
func_name = f"{self}.{method_name}"
if isinstance(samples[0], SampleSet):
return [
SampleSet(
self._samples_transform(sset.samples, method_name), parent=sset,
)
for sset in samples
]
else:
kwargs = _make_kwargs_from_samples(samples, self.transform_extra_arguments)
delayed = DelayedSamplesCall(partial(method, **kwargs), func_name, samples,)
new_samples = [
DelayedSample(load=s.load,annotations=partial(delayed, index=i)(), parent=s)
for i, s in enumerate(samples)
]
return new_samples
class CheckpointWrapper(BaseWrapper, TransformerMixin):
"""Wraps :any:`Sample`-based estimators so the results are saved in
disk."""
......@@ -315,6 +393,96 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
cloudpickle.dump(self, f)
return self
class CheckpointAnnotationsWrapper(CheckpointWrapper):
"""Wraps :any:`Sample`-based estimators so the annotations are saved to
disk."""
def __init__(
self,
annotator,
annotations_dir=None,
extension=".json",
save_func=None,
load_func=None,
force=False,
**kwargs,
):
save_func = save_func or json_dump
load_func = load_func or json_load
super().__init__(
estimator=annotator,
features_dir=annotations_dir,
extension=extension,
save_func=save_func,
load_func=load_func,
**kwargs,
)
self.force = force
def _checkpoint_transform(self, samples, method_name):
# Transform either samples or samplesets
method = getattr(self.estimator, method_name)
logger.debug(f"{_frmt(self)}.{method_name}")
# if features_dir is None, just transform all samples at once
if self.features_dir is None:
return method(samples)
def _transform_samples(samples):
paths = [self.make_path(s) for s in samples]
should_compute_list = [
p is None or not os.path.isfile(p) or self.force
for p in paths
]
# call method on non-checkpointed samples
non_existing_samples = [
s
for s, should_compute in zip(samples, should_compute_list)
if should_compute
]
# non_existing_samples could be empty
computed_features = []
if non_existing_samples:
computed_features = method(non_existing_samples)
_check_n_input_output(non_existing_samples, computed_features, method)
# return computed features and checkpointed features
features, com_feat_index = [], 0
for s, p, should_compute in zip(samples, paths, should_compute_list):
if should_compute:
feat = computed_features[com_feat_index]
com_feat_index += 1
# save the computed feature
if p is not None:
self.save(feat)
feat = self.load(s, p)
s.annotations = feat
else:
s.annotations=self.load(s, p)
return samples
if isinstance(samples[0], SampleSet):
return [SampleSet(_transform_samples(s.samples), parent=s) for s in samples]
else:
return _transform_samples(samples)
def save(self, sample):
"""
Saves a sample's annotations to disk using self.save_func.
Overrides CheckpointAnnotations.save
"""
path = self.make_path(sample)
os.makedirs(os.path.dirname(path), exist_ok=True)
return self.save_func(sample.annotations, path)
def load(self, sample, path):
"""
Loads a sample's annotations from disk using self.load_func.
Overrides CheckpointAnnotations.load
"""
return self.load_func(path)
class DaskWrapper(BaseWrapper, TransformerMixin):
"""Wraps Scikit estimators to handle Dask Bags as input.
......@@ -459,6 +627,8 @@ def wrap(bases, estimator=None, **kwargs):
"sample": SampleWrapper,
"checkpoint": CheckpointWrapper,
"dask": DaskWrapper,
"annotated_sample": AnnotatedSampleWrapper,
"checkpoint_annotations": CheckpointAnnotationsWrapper,
}[w.lower()]
def _wrap(estimator, **kwargs):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment