From 40331e70062e1a2c11d6d1580d03603a12c7ab9e Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Fri, 20 Nov 2020 00:36:39 +0100
Subject: [PATCH] [annotators][FailSafe] improvements

---
 bob/bio/base/annotator/FailSafe.py | 95 +++++++++++++++++-------------
 1 file changed, 53 insertions(+), 42 deletions(-)

diff --git a/bob/bio/base/annotator/FailSafe.py b/bob/bio/base/annotator/FailSafe.py
index f30dc012..40433866 100644
--- a/bob/bio/base/annotator/FailSafe.py
+++ b/bob/bio/base/annotator/FailSafe.py
@@ -1,17 +1,9 @@
 import logging
-import six
 from . import Annotator
 from .. import load_resource
 
 logger = logging.getLogger(__name__)
 
-def _isolate_kwargs(kwargs_dict, index):
-    """
-    Returns the kwargs to pass down to an annotator.
-
-    Each annotator is expecting a batch of samples and the corresponding kwargs.
-    """
-    return {k:[v[index]] for k,v in kwargs_dict.items()}
 
 class FailSafe(Annotator):
     """A fail-safe annotator.
@@ -30,18 +22,47 @@ class FailSafe(Annotator):
         If True, the annotations will only contain the ``required_keys``.
     """
 
-    def __init__(self, annotators, required_keys, only_required_keys=False,
-                 **kwargs):
+    def __init__(self, annotators, required_keys, only_required_keys=False, **kwargs):
         super(FailSafe, self).__init__(**kwargs)
         self.annotators = []
         for annotator in annotators:
-            if isinstance(annotator, six.string_types):
-                annotator = load_resource(annotator, 'annotator')
+            if isinstance(annotator, str):
+                annotator = load_resource(annotator, "annotator")
             self.annotators.append(annotator)
         self.required_keys = list(required_keys)
         self.only_required_keys = only_required_keys
 
-    def transform(self, sample_batch, **kwargs):
+    def annotate(self, sample, **kwargs):
+        if "annotations" not in kwargs or kwargs["annotations"] is None:
+            kwargs["annotations"] = {}
+        for annotator in self.annotators:
+            try:
+                annotations = annotator.transform(
+                    [sample], **{k: [v] for k, v in kwargs.items()}
+                )[0]
+            except Exception:
+                logger.debug(
+                    "The annotator `%s' failed to annotate!", annotator, exc_info=True
+                )
+                annotations = None
+            if not annotations:
+                logger.debug("Annotator `%s' returned empty annotations.", annotator)
+            else:
+                logger.debug("Annotator `%s' succeeded!", annotator)
+            kwargs["annotations"].update(annotations or {})
+            # check if we have all the required annotations
+            if all(key in kwargs["annotations"] for key in self.required_keys):
+                break
+        else:  # this else is for the for loop
+            # we don't want to return half of the annotations
+            kwargs["annotations"] = None
+        if self.only_required_keys:
+            for key in list(kwargs["annotations"].keys()):
+                if key not in self.required_keys:
+                    del kwargs["annotations"][key]
+        return kwargs["annotations"]
+
+    def transform(self, samples, **kwargs):
         """
         Takes a batch of data and tries annotating them while unsuccessful.
 
@@ -53,32 +74,22 @@ class FailSafe(Annotator):
         with ``[s1, s2, ...]`` as ``samples_batch``, ``kwargs['annotations']``
         should contain ``[{<s1_annotations>}, {<s2_annotations>}, ...]``).
         """
-        if 'annotations' not in kwargs or kwargs['annotations'] is None:
-            kwargs['annotations'] = [{}] * len(sample_batch)
-        # Executes on each sample and corresponding existing annotations
-        for index, (sample, annotations) in enumerate(zip(sample_batch, kwargs['annotations'])):
-            for annotator in self.annotators:
-                try:
-                    annot = annotator([sample], **_isolate_kwargs(kwargs, index))[0]
-                except Exception:
-                    logger.debug(
-                        "The annotator `%s' failed to annotate!", annotator,
-                        exc_info=True)
-                    annot = None
-                if not annot:
-                    logger.debug(
-                        "Annotator `%s' returned empty annotations.", annotator)
-                else:
-                    logger.debug("Annotator `%s' succeeded!", annotator)
-                annotations.update(annot or {})
-                # check if we have all the required annotations
-                if all(key in annotations for key in self.required_keys):
-                    break
-            else:  # this else is for the for loop
-                # we don't want to return half of the annotations
-                annotations = None
-            if self.only_required_keys:
-                for key in list(annotations.keys()):
-                    if key not in self.required_keys:
-                        del annotations[key]
-        return kwargs['annotations']
+        kwargs = translate_kwargs(kwargs, len(samples))
+        return [self.annotate(sample, **kw) for sample, kw in zip(samples, kwargs)]
+
+
+def translate_kwargs(kwargs, size):
+    new_kwargs = [{}] * size
+
+    if not kwargs:
+        return new_kwargs
+
+    for k, value_list in kwargs.items():
+        if len(value_list) != size:
+            raise ValueError(
+                f"Got {value_list} in kwargs which is not of the same length of samples {size}"
+            )
+        for kw, v in zip(new_kwargs, value_list):
+            kw[k] = v
+
+    return new_kwargs
-- 
GitLab