diff --git a/bob/pipelines/__init__.py b/bob/pipelines/__init__.py index baba90ef0d3bb3301ad7635db3691b48a80e19ac..1cc56dbfe908c85a5b4fc59371994edf3deacaca 100644 --- a/bob/pipelines/__init__.py +++ b/bob/pipelines/__init__.py @@ -10,9 +10,11 @@ from .sample import hdf5_to_sample # noqa from .sample import sample_to_hdf5 # noqa from .wrappers import BaseWrapper from .wrappers import CheckpointWrapper +from .wrappers import CheckpointAnnotationsWrapper from .wrappers import DaskWrapper from .wrappers import DelayedSamplesCall from .wrappers import SampleWrapper +from .wrappers import AnnotatedSampleWrapper from .wrappers import ToDaskBag from .wrappers import dask_tags # noqa from .wrappers import wrap # noqa diff --git a/bob/pipelines/wrappers.py b/bob/pipelines/wrappers.py index 0297c78c97057ef3f0e39d2e34a100920b08919a..bc80c9c5f3c43c85cb9c9cd490c5af0bd663c2e2 100644 --- a/bob/pipelines/wrappers.py +++ b/bob/pipelines/wrappers.py @@ -1,5 +1,6 @@ """Scikit-learn Estimator Wrappers.""" import logging +import json import os from functools import partial @@ -44,6 +45,15 @@ def copy_learned_attributes(from_estimator, to_estimator): setattr(to_estimator, k, v) +def json_dump(data, path): + with open(path, "w") as f: + json.dump(data, f) + +def json_load(path): + with open(path, "r") as f: + return json.load(f) + + class BaseWrapper(MetaEstimatorMixin, BaseEstimator): """The base class for all wrappers.""" @@ -174,6 +184,74 @@ class SampleWrapper(BaseWrapper, TransformerMixin): return self +class AnnotatedSampleWrapper(SampleWrapper): + """Wraps an annotator Transformer to set the sample.annotations correctly. + + An :py:class:`~bob.bio.base.Annotator` transformer simply returns its + results, (or in the :py:attr:`~bob.pipelines.Sample.data` attribute when + wrapped with :py:class:`~bob.pipelines.SampleWrapper`). + + Use this wrapper uniquely with Annotators, and INSTEAD of the + :py:class:`~bob.pipelines.SampleWrapper`. + + Attributes + ---------- + fit_extra_arguments : [tuple] + Use this option if you want to pass extra arguments to the fit method of the + mixed instance. The format is a list of two value tuples. The first value in + tuples is the name of the argument that fit accepts, like ``y``, and the second + value is the name of the attribute that samples carry. For example, if you are + passing samples to the fit method and want to pass ``subject`` attributes of + samples as the ``y`` argument to the fit method, you can provide ``[("y", + "subject")]`` as the value for this attribute. + transform_extra_arguments : [tuple] + Similar to ``fit_extra_arguments`` but for the transform and other similar methods. + """ + + def __init__( + self, + annotator, + transform_extra_arguments=None, + fit_extra_arguments=None, + **kwargs, + ): + super().__init__( + estimator=annotator, + transform_extra_arguments=transform_extra_arguments, + fit_extra_arguments=fit_extra_arguments, + **kwargs, + ) + + def _samples_transform(self, samples, method_name): + """ + Transforms a set of samples by calling the annotator with any method. + + Overrides SampleWrapper.sample_transform to insert annotations in their + field (:py:attr:`~bob.pipelines.Sample.annotations`) instead of the + :py:attr:`~bob.pipelines.Sample.data` field. + """ + # Transform either samples or samplesets + method = getattr(self.estimator, method_name) + logger.debug(f"{_frmt(self)}.{method_name}") + func_name = f"{self}.{method_name}" + + if isinstance(samples[0], SampleSet): + return [ + SampleSet( + self._samples_transform(sset.samples, method_name), parent=sset, + ) + for sset in samples + ] + else: + kwargs = _make_kwargs_from_samples(samples, self.transform_extra_arguments) + delayed = DelayedSamplesCall(partial(method, **kwargs), func_name, samples,) + new_samples = [ + DelayedSample(load=s.load,annotations=partial(delayed, index=i)(), parent=s) + for i, s in enumerate(samples) + ] + return new_samples + + class CheckpointWrapper(BaseWrapper, TransformerMixin): """Wraps :any:`Sample`-based estimators so the results are saved in disk.""" @@ -315,6 +393,96 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin): cloudpickle.dump(self, f) return self +class CheckpointAnnotationsWrapper(CheckpointWrapper): + """Wraps :any:`Sample`-based estimators so the annotations are saved to + disk.""" + + def __init__( + self, + annotator, + annotations_dir=None, + extension=".json", + save_func=None, + load_func=None, + force=False, + **kwargs, + ): + save_func = save_func or json_dump + load_func = load_func or json_load + super().__init__( + estimator=annotator, + features_dir=annotations_dir, + extension=extension, + save_func=save_func, + load_func=load_func, + **kwargs, + ) + self.force = force + + def _checkpoint_transform(self, samples, method_name): + # Transform either samples or samplesets + method = getattr(self.estimator, method_name) + logger.debug(f"{_frmt(self)}.{method_name}") + + # if features_dir is None, just transform all samples at once + if self.features_dir is None: + return method(samples) + + def _transform_samples(samples): + paths = [self.make_path(s) for s in samples] + should_compute_list = [ + p is None or not os.path.isfile(p) or self.force + for p in paths + ] + # call method on non-checkpointed samples + non_existing_samples = [ + s + for s, should_compute in zip(samples, should_compute_list) + if should_compute + ] + # non_existing_samples could be empty + computed_features = [] + if non_existing_samples: + computed_features = method(non_existing_samples) + _check_n_input_output(non_existing_samples, computed_features, method) + # return computed features and checkpointed features + features, com_feat_index = [], 0 + for s, p, should_compute in zip(samples, paths, should_compute_list): + if should_compute: + feat = computed_features[com_feat_index] + com_feat_index += 1 + # save the computed feature + if p is not None: + self.save(feat) + feat = self.load(s, p) + s.annotations = feat + else: + s.annotations=self.load(s, p) + return samples + + if isinstance(samples[0], SampleSet): + return [SampleSet(_transform_samples(s.samples), parent=s) for s in samples] + else: + return _transform_samples(samples) + + def save(self, sample): + """ + Saves a sample's annotations to disk using self.save_func. + + Overrides CheckpointAnnotations.save + """ + path = self.make_path(sample) + os.makedirs(os.path.dirname(path), exist_ok=True) + return self.save_func(sample.annotations, path) + + def load(self, sample, path): + """ + Loads a sample's annotations from disk using self.load_func. + + Overrides CheckpointAnnotations.load + """ + return self.load_func(path) + class DaskWrapper(BaseWrapper, TransformerMixin): """Wraps Scikit estimators to handle Dask Bags as input. @@ -459,6 +627,8 @@ def wrap(bases, estimator=None, **kwargs): "sample": SampleWrapper, "checkpoint": CheckpointWrapper, "dask": DaskWrapper, + "annotated_sample": AnnotatedSampleWrapper, + "checkpoint_annotations": CheckpointAnnotationsWrapper, }[w.lower()] def _wrap(estimator, **kwargs):