Commit 881425da authored by Yannick DAYER's avatar Yannick DAYER
Browse files

[annotator] Correct kwargs receiving and passing.

parent 42322c0a
......@@ -5,6 +5,13 @@ from .. import load_resource
logger = logging.getLogger(__name__)
def _isolate_kwargs(kwargs_dict, index):
Returns the kwargs to pass down to an annotator.
Each annotator is expecting a batch of samples and the corresponding kwargs.
return {k:[v[index]] for k,v in kwargs_dict.items()}
class FailSafe(Annotator):
"""A fail-safe annotator.
......@@ -35,14 +42,24 @@ class FailSafe(Annotator):
self.only_required_keys = only_required_keys
def transform(self, sample_batch, **kwargs):
Takes a batch of data and tries annotating them while unsuccessful.
Tries each annotator given at the creation of FailSafe when the previous
one fails.
Each ``kwargs`` value is a list of parameters, with each element of those
lists corresponding to each element of ``sample_batch`` (for example:
with ``[s1, s2, ...]`` as ``samples_batch``, ``kwargs['annotations']``
should contain ``[{<s1_annotations>}, {<s2_annotations>}, ...]``).
if 'annotations' not in kwargs or kwargs['annotations'] is None:
kwargs['annotations'] = {}
all_annotations = []
for sample in sample_batch:
annotations = kwargs['annotations'].copy()
kwargs['annotations'] = [{}] * len(sample_batch)
# Executes on each sample and corresponding existing annotations
for index, (sample, annotations) in enumerate(zip(sample_batch, kwargs['annotations'])):
for annotator in self.annotators:
annot = annotator([sample], **kwargs)[0]
annot = annotator([sample], **_isolate_kwargs(kwargs, index))[0]
except Exception:
"The annotator `%s' failed to annotate!", annotator,
......@@ -64,5 +81,4 @@ class FailSafe(Annotator):
for key in list(annotations.keys()):
if key not in self.required_keys:
del annotations[key]
return all_annotations
return kwargs['annotations']
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment