Skip to content
Snippets Groups Projects
Commit 40331e70 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[annotators][FailSafe] improvements

parent d7470dfe
Branches
No related tags found
1 merge request!210[annotators][FailSafe] improvements
Pipeline #45878 passed
import logging import logging
import six
from . import Annotator from . import Annotator
from .. import load_resource from .. import load_resource
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _isolate_kwargs(kwargs_dict, index):
"""
Returns the kwargs to pass down to an annotator.
Each annotator is expecting a batch of samples and the corresponding kwargs.
"""
return {k:[v[index]] for k,v in kwargs_dict.items()}
class FailSafe(Annotator): class FailSafe(Annotator):
"""A fail-safe annotator. """A fail-safe annotator.
...@@ -30,55 +22,74 @@ class FailSafe(Annotator): ...@@ -30,55 +22,74 @@ class FailSafe(Annotator):
If True, the annotations will only contain the ``required_keys``. If True, the annotations will only contain the ``required_keys``.
""" """
def __init__(self, annotators, required_keys, only_required_keys=False, def __init__(self, annotators, required_keys, only_required_keys=False, **kwargs):
**kwargs):
super(FailSafe, self).__init__(**kwargs) super(FailSafe, self).__init__(**kwargs)
self.annotators = [] self.annotators = []
for annotator in annotators: for annotator in annotators:
if isinstance(annotator, six.string_types): if isinstance(annotator, str):
annotator = load_resource(annotator, 'annotator') annotator = load_resource(annotator, "annotator")
self.annotators.append(annotator) self.annotators.append(annotator)
self.required_keys = list(required_keys) self.required_keys = list(required_keys)
self.only_required_keys = only_required_keys self.only_required_keys = only_required_keys
def transform(self, sample_batch, **kwargs): def annotate(self, sample, **kwargs):
""" if "annotations" not in kwargs or kwargs["annotations"] is None:
Takes a batch of data and tries annotating them while unsuccessful. kwargs["annotations"] = {}
Tries each annotator given at the creation of FailSafe when the previous
one fails.
Each ``kwargs`` value is a list of parameters, with each element of those
lists corresponding to each element of ``sample_batch`` (for example:
with ``[s1, s2, ...]`` as ``samples_batch``, ``kwargs['annotations']``
should contain ``[{<s1_annotations>}, {<s2_annotations>}, ...]``).
"""
if 'annotations' not in kwargs or kwargs['annotations'] is None:
kwargs['annotations'] = [{}] * len(sample_batch)
# Executes on each sample and corresponding existing annotations
for index, (sample, annotations) in enumerate(zip(sample_batch, kwargs['annotations'])):
for annotator in self.annotators: for annotator in self.annotators:
try: try:
annot = annotator([sample], **_isolate_kwargs(kwargs, index))[0] annotations = annotator.transform(
[sample], **{k: [v] for k, v in kwargs.items()}
)[0]
except Exception: except Exception:
logger.debug( logger.debug(
"The annotator `%s' failed to annotate!", annotator, "The annotator `%s' failed to annotate!", annotator, exc_info=True
exc_info=True) )
annot = None annotations = None
if not annot: if not annotations:
logger.debug( logger.debug("Annotator `%s' returned empty annotations.", annotator)
"Annotator `%s' returned empty annotations.", annotator)
else: else:
logger.debug("Annotator `%s' succeeded!", annotator) logger.debug("Annotator `%s' succeeded!", annotator)
annotations.update(annot or {}) kwargs["annotations"].update(annotations or {})
# check if we have all the required annotations # check if we have all the required annotations
if all(key in annotations for key in self.required_keys): if all(key in kwargs["annotations"] for key in self.required_keys):
break break
else: # this else is for the for loop else: # this else is for the for loop
# we don't want to return half of the annotations # we don't want to return half of the annotations
annotations = None kwargs["annotations"] = None
if self.only_required_keys: if self.only_required_keys:
for key in list(annotations.keys()): for key in list(kwargs["annotations"].keys()):
if key not in self.required_keys: if key not in self.required_keys:
del annotations[key] del kwargs["annotations"][key]
return kwargs['annotations'] return kwargs["annotations"]
def transform(self, samples, **kwargs):
"""
Takes a batch of data and tries annotating them while unsuccessful.
Tries each annotator given at the creation of FailSafe when the previous
one fails.
Each ``kwargs`` value is a list of parameters, with each element of those
lists corresponding to each element of ``sample_batch`` (for example:
with ``[s1, s2, ...]`` as ``samples_batch``, ``kwargs['annotations']``
should contain ``[{<s1_annotations>}, {<s2_annotations>}, ...]``).
"""
kwargs = translate_kwargs(kwargs, len(samples))
return [self.annotate(sample, **kw) for sample, kw in zip(samples, kwargs)]
def translate_kwargs(kwargs, size):
new_kwargs = [{}] * size
if not kwargs:
return new_kwargs
for k, value_list in kwargs.items():
if len(value_list) != size:
raise ValueError(
f"Got {value_list} in kwargs which is not of the same length of samples {size}"
)
for kw, v in zip(new_kwargs, value_list):
kw[k] = v
return new_kwargs
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment