Commit 6438fe68 authored by Amir MOHAMMADI
Browse files

Merge branch 'fix-dask-annotators' into 'master'

Fix annotators kwargs

See merge request !205
parents 42322c0a 28ca9ae5
Pipeline #45365 failed with stages
in 6 minutes and 54 seconds
......@@ -5,6 +5,13 @@ from .. import load_resource
logger = logging.getLogger(__name__)
def _isolate_kwargs(kwargs_dict, index):
"""
Returns the kwargs to pass down to an annotator.
Each annotator is expecting a batch of samples and the corresponding kwargs.
"""
return {k:[v[index]] for k,v in kwargs_dict.items()}
class FailSafe(Annotator):
"""A fail-safe annotator.
......@@ -35,14 +42,24 @@ class FailSafe(Annotator):
self.only_required_keys = only_required_keys
def transform(self, sample_batch, **kwargs):
"""
Takes a batch of data and tries annotating them while unsuccessful.
Tries each annotator given at the creation of FailSafe when the previous
one fails.
Each ``kwargs`` value is a list of parameters, with each element of those
lists corresponding to each element of ``sample_batch`` (for example:
with ``[s1, s2, ...]`` as ``sample_batch``, ``kwargs['annotations']``
should contain ``[{<s1_annotations>}, {<s2_annotations>}, ...]``).
"""
if 'annotations' not in kwargs or kwargs['annotations'] is None:
kwargs['annotations'] = {}
all_annotations = []
for sample in sample_batch:
annotations = kwargs['annotations'].copy()
kwargs['annotations'] = [{}] * len(sample_batch)
# Executes on each sample and corresponding existing annotations
for index, (sample, annotations) in enumerate(zip(sample_batch, kwargs['annotations'])):
for annotator in self.annotators:
try:
annot = annotator([sample], **kwargs)[0]
annot = annotator([sample], **_isolate_kwargs(kwargs, index))[0]
except Exception:
logger.debug(
"The annotator `%s' failed to annotate!", annotator,
......@@ -64,5 +81,4 @@ class FailSafe(Annotator):
for key in list(annotations.keys()):
if key not in self.required_keys:
del annotations[key]
all_annotations.append(annotations)
return all_annotations
return kwargs['annotations']
......@@ -34,8 +34,8 @@ def annotate_common_options(func):
required=True,
cls=ResourceOption,
entry_point_group="bob.bio.annotator",
help="A Transformer instance that takes a series of sample and returns "
"the modified samples with annotations as a dictionary.",
help="An annotator (instance of class inheriting from "
"bob.bio.base.Annotator) or an annotator resource name.",
)
@click.option(
"--output-dir",
......@@ -148,15 +148,14 @@ def annotate(
else:
scheduler="single-threaded"
# Splits the samples list into bags
dask_bags = to_dask_bags.transform(samples)
logger.info(f"Saving annotations in {output_dir}.")
logger.info(f"Annotating {len(samples)} samples...")
dask_bags = to_dask_bags.transform(samples)
annotator.transform(dask_bags).compute(scheduler=scheduler)
if dask_client is not None:
logger.info("Shutdown workers...")
dask_client.shutdown()
logger.info("Done.")
logger.info("All annotations written.")
@click.command(
......@@ -254,6 +253,7 @@ def annotate_samples(
)
for s in samples
]
# Splits the samples list into bags
dask_bags = to_dask_bags.transform(samples_obj)
......@@ -261,7 +261,4 @@ def annotate_samples(
logger.info(f"Annotating {len(samples_obj)} samples...")
annotator.transform(dask_bags).compute(scheduler=scheduler)
if dask_client is not None:
logger.info("Shutdown workers...")
dask_client.shutdown()
logger.info("Done.")
logger.info("All annotations written.")
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment