Skip to content
Snippets Groups Projects
Commit 40331e70 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[annotators][FailSafe] improvements

parent d7470dfe
Branches
No related tags found
1 merge request!210[annotators][FailSafe] improvements
Pipeline #45878 passed
import logging import logging
import six
from . import Annotator from . import Annotator
from .. import load_resource from .. import load_resource
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _isolate_kwargs(kwargs_dict, index):
"""
Returns the kwargs to pass down to an annotator.
Each annotator is expecting a batch of samples and the corresponding kwargs.
"""
return {k:[v[index]] for k,v in kwargs_dict.items()}
class FailSafe(Annotator): class FailSafe(Annotator):
"""A fail-safe annotator. """A fail-safe annotator.
...@@ -30,18 +22,47 @@ class FailSafe(Annotator): ...@@ -30,18 +22,47 @@ class FailSafe(Annotator):
If True, the annotations will only contain the ``required_keys``. If True, the annotations will only contain the ``required_keys``.
""" """
def __init__(self, annotators, required_keys, only_required_keys=False, def __init__(self, annotators, required_keys, only_required_keys=False, **kwargs):
**kwargs):
super(FailSafe, self).__init__(**kwargs) super(FailSafe, self).__init__(**kwargs)
self.annotators = [] self.annotators = []
for annotator in annotators: for annotator in annotators:
if isinstance(annotator, six.string_types): if isinstance(annotator, str):
annotator = load_resource(annotator, 'annotator') annotator = load_resource(annotator, "annotator")
self.annotators.append(annotator) self.annotators.append(annotator)
self.required_keys = list(required_keys) self.required_keys = list(required_keys)
self.only_required_keys = only_required_keys self.only_required_keys = only_required_keys
def transform(self, sample_batch, **kwargs): def annotate(self, sample, **kwargs):
if "annotations" not in kwargs or kwargs["annotations"] is None:
kwargs["annotations"] = {}
for annotator in self.annotators:
try:
annotations = annotator.transform(
[sample], **{k: [v] for k, v in kwargs.items()}
)[0]
except Exception:
logger.debug(
"The annotator `%s' failed to annotate!", annotator, exc_info=True
)
annotations = None
if not annotations:
logger.debug("Annotator `%s' returned empty annotations.", annotator)
else:
logger.debug("Annotator `%s' succeeded!", annotator)
kwargs["annotations"].update(annotations or {})
# check if we have all the required annotations
if all(key in kwargs["annotations"] for key in self.required_keys):
break
else: # this else is for the for loop
# we don't want to return half of the annotations
kwargs["annotations"] = None
if self.only_required_keys:
for key in list(kwargs["annotations"].keys()):
if key not in self.required_keys:
del kwargs["annotations"][key]
return kwargs["annotations"]
def transform(self, samples, **kwargs):
""" """
Takes a batch of data and tries annotating them while unsuccessful. Takes a batch of data and tries annotating them while unsuccessful.
...@@ -53,32 +74,22 @@ class FailSafe(Annotator): ...@@ -53,32 +74,22 @@ class FailSafe(Annotator):
with ``[s1, s2, ...]`` as ``samples_batch``, ``kwargs['annotations']`` with ``[s1, s2, ...]`` as ``samples_batch``, ``kwargs['annotations']``
should contain ``[{<s1_annotations>}, {<s2_annotations>}, ...]``). should contain ``[{<s1_annotations>}, {<s2_annotations>}, ...]``).
""" """
if 'annotations' not in kwargs or kwargs['annotations'] is None: kwargs = translate_kwargs(kwargs, len(samples))
kwargs['annotations'] = [{}] * len(sample_batch) return [self.annotate(sample, **kw) for sample, kw in zip(samples, kwargs)]
# Executes on each sample and corresponding existing annotations
for index, (sample, annotations) in enumerate(zip(sample_batch, kwargs['annotations'])):
for annotator in self.annotators: def translate_kwargs(kwargs, size):
try: new_kwargs = [{}] * size
annot = annotator([sample], **_isolate_kwargs(kwargs, index))[0]
except Exception: if not kwargs:
logger.debug( return new_kwargs
"The annotator `%s' failed to annotate!", annotator,
exc_info=True) for k, value_list in kwargs.items():
annot = None if len(value_list) != size:
if not annot: raise ValueError(
logger.debug( f"Got {value_list} in kwargs which is not of the same length of samples {size}"
"Annotator `%s' returned empty annotations.", annotator) )
else: for kw, v in zip(new_kwargs, value_list):
logger.debug("Annotator `%s' succeeded!", annotator) kw[k] = v
annotations.update(annot or {})
# check if we have all the required annotations return new_kwargs
if all(key in annotations for key in self.required_keys):
break
else: # this else is for the for loop
# we don't want to return half of the annotations
annotations = None
if self.only_required_keys:
for key in list(annotations.keys()):
if key not in self.required_keys:
del annotations[key]
return kwargs['annotations']
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment