import logging

from . import Annotator
from .. import load_resource

logger = logging.getLogger(__name__)


class FailSafe(Annotator):
    """A fail-safe annotator.

    Tries the configured annotators one after another on a sample and stops
    as soon as the accumulated annotations contain every key listed in
    ``required_keys``.

    Parameters
    ----------
    annotators : list
        Annotators to try, in order. String entries are resolved with
        ``load_resource(name, "annotator")``.
    required_keys : list
        Keys that must all be present in the annotations for the annotation
        to be considered successful.
    only_required_keys : bool
        If True, the annotations will only contain the ``required_keys``.
    **kwargs
        Forwarded to the ``Annotator`` base-class constructor.
    """

    def __init__(self, annotators, required_keys, only_required_keys=False, **kwargs):
        super().__init__(**kwargs)
        self.annotators = []
        for annotator in annotators:
            # Annotators may be given by resource name instead of by instance.
            if isinstance(annotator, str):
                annotator = load_resource(annotator, "annotator")
            self.annotators.append(annotator)
        self.required_keys = list(required_keys)
        self.only_required_keys = only_required_keys

    def annotate(self, sample, **kwargs):
        """Annotates one sample, falling back to the next annotator on failure.

        Parameters
        ----------
        sample
            The sample to annotate; passed to each annotator as a batch of one.
        **kwargs
            Per-sample keyword arguments. ``kwargs["annotations"]``, when
            given and not ``None``, is used as the starting annotations and
            is updated in place. Every kwarg is wrapped in a one-element list
            before being handed to ``annotator.transform``.

        Returns
        -------
        dict or None
            The accumulated annotations (restricted to ``required_keys`` when
            ``only_required_keys`` is set), or ``None`` when no annotator
            managed to provide all the required keys.
        """
        if "annotations" not in kwargs or kwargs["annotations"] is None:
            kwargs["annotations"] = {}
        for annotator in self.annotators:
            try:
                annotations = annotator.transform(
                    [sample], **{k: [v] for k, v in kwargs.items()}
                )[0]
            except Exception:
                # Deliberate best effort: a failing annotator must not stop
                # the chain, the next one gets its chance.
                logger.debug(
                    "The annotator `%s' failed to annotate!", annotator, exc_info=True
                )
                annotations = None
            if not annotations:
                logger.debug("Annotator `%s' returned empty annotations.", annotator)
            else:
                logger.debug("Annotator `%s' succeeded!", annotator)
            kwargs["annotations"].update(annotations or {})
            # check if we have all the required annotations
            if all(key in kwargs["annotations"] for key in self.required_keys):
                break
        else:  # this else is for the for loop
            # we don't want to return half of the annotations; return right
            # away so the key filtering below never runs on ``None``.
            return None
        if self.only_required_keys:
            # Iterate over a snapshot of the keys since entries are deleted.
            for key in list(kwargs["annotations"].keys()):
                if key not in self.required_keys:
                    del kwargs["annotations"][key]
        return kwargs["annotations"]

    def transform(self, samples, **kwargs):
        """Takes a batch of data and tries annotating them while unsuccessful.

        Parameters
        ----------
        samples : list
            The batch of samples to annotate.
        **kwargs
            Batch-level keyword arguments: each value must hold one entry per
            sample (e.g. with ``[s1, s2, ...]`` as ``samples``,
            ``kwargs['annotations']`` should contain
            ``[{<s1_annotations>}, {<s2_annotations>}, ...]``).
            ``annotations=None`` is treated as "no existing annotations".

        Returns
        -------
        list
            One entry per sample, as returned by :meth:`annotate`.
        """
        # ``None`` has no length and would make the per-sample split fail;
        # dropping it lets ``annotate`` start from empty annotations.
        if "annotations" in kwargs and kwargs["annotations"] is None:
            del kwargs["annotations"]
        per_sample_kwargs = translate_kwargs(kwargs, len(samples))
        return [
            self.annotate(sample, **kw)
            for sample, kw in zip(samples, per_sample_kwargs)
        ]
+ +def translate_kwargs(kwargs, size): + new_kwargs = [{}] * size + + if not kwargs: + return new_kwargs + + for k, value_list in kwargs.items(): + if len(value_list) != size: + raise ValueError( + f"Got {value_list} in kwargs which is not of the same length of samples {size}" + ) + for kw, v in zip(new_kwargs, value_list): + kw[k] = v + + return new_kwargs