Skip to content
Snippets Groups Projects
Commit 40331e70 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[annotators][FailSafe] improvements

parent d7470dfe
Branches
No related tags found
1 merge request!210[annotators][FailSafe] improvements
Pipeline #45878 passed
import logging
import six
from . import Annotator
from .. import load_resource
logger = logging.getLogger(__name__)
def _isolate_kwargs(kwargs_dict, index):
"""
Returns the kwargs to pass down to an annotator.
Each annotator is expecting a batch of samples and the corresponding kwargs.
"""
return {k:[v[index]] for k,v in kwargs_dict.items()}
class FailSafe(Annotator):
"""A fail-safe annotator.
......@@ -30,18 +22,47 @@ class FailSafe(Annotator):
If True, the annotations will only contain the ``required_keys``.
"""
def __init__(self, annotators, required_keys, only_required_keys=False,
**kwargs):
def __init__(self, annotators, required_keys, only_required_keys=False, **kwargs):
super(FailSafe, self).__init__(**kwargs)
self.annotators = []
for annotator in annotators:
if isinstance(annotator, six.string_types):
annotator = load_resource(annotator, 'annotator')
if isinstance(annotator, str):
annotator = load_resource(annotator, "annotator")
self.annotators.append(annotator)
self.required_keys = list(required_keys)
self.only_required_keys = only_required_keys
def transform(self, sample_batch, **kwargs):
def annotate(self, sample, **kwargs):
if "annotations" not in kwargs or kwargs["annotations"] is None:
kwargs["annotations"] = {}
for annotator in self.annotators:
try:
annotations = annotator.transform(
[sample], **{k: [v] for k, v in kwargs.items()}
)[0]
except Exception:
logger.debug(
"The annotator `%s' failed to annotate!", annotator, exc_info=True
)
annotations = None
if not annotations:
logger.debug("Annotator `%s' returned empty annotations.", annotator)
else:
logger.debug("Annotator `%s' succeeded!", annotator)
kwargs["annotations"].update(annotations or {})
# check if we have all the required annotations
if all(key in kwargs["annotations"] for key in self.required_keys):
break
else: # this else is for the for loop
# we don't want to return half of the annotations
kwargs["annotations"] = None
if self.only_required_keys:
for key in list(kwargs["annotations"].keys()):
if key not in self.required_keys:
del kwargs["annotations"][key]
return kwargs["annotations"]
def transform(self, samples, **kwargs):
"""
Takes a batch of data and tries annotating them while unsuccessful.
......@@ -53,32 +74,22 @@ class FailSafe(Annotator):
with ``[s1, s2, ...]`` as ``samples_batch``, ``kwargs['annotations']``
should contain ``[{<s1_annotations>}, {<s2_annotations>}, ...]``).
"""
if 'annotations' not in kwargs or kwargs['annotations'] is None:
kwargs['annotations'] = [{}] * len(sample_batch)
# Executes on each sample and corresponding existing annotations
for index, (sample, annotations) in enumerate(zip(sample_batch, kwargs['annotations'])):
for annotator in self.annotators:
try:
annot = annotator([sample], **_isolate_kwargs(kwargs, index))[0]
except Exception:
logger.debug(
"The annotator `%s' failed to annotate!", annotator,
exc_info=True)
annot = None
if not annot:
logger.debug(
"Annotator `%s' returned empty annotations.", annotator)
else:
logger.debug("Annotator `%s' succeeded!", annotator)
annotations.update(annot or {})
# check if we have all the required annotations
if all(key in annotations for key in self.required_keys):
break
else: # this else is for the for loop
# we don't want to return half of the annotations
annotations = None
if self.only_required_keys:
for key in list(annotations.keys()):
if key not in self.required_keys:
del annotations[key]
return kwargs['annotations']
kwargs = translate_kwargs(kwargs, len(samples))
return [self.annotate(sample, **kw) for sample, kw in zip(samples, kwargs)]
def translate_kwargs(kwargs, size):
new_kwargs = [{}] * size
if not kwargs:
return new_kwargs
for k, value_list in kwargs.items():
if len(value_list) != size:
raise ValueError(
f"Got {value_list} in kwargs which is not of the same length of samples {size}"
)
for kw, v in zip(new_kwargs, value_list):
kw[k] = v
return new_kwargs
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment