Commit abb90779 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

FailSafe annotator optionally return only the required keys

parent d3d878ac
Pipeline #18556 passed with stage
......@@ -19,9 +19,12 @@ class FailSafe(Annotator):
required_keys : list
A list of keys that should be available in annotations to stop trying
different annotators.
only_required_keys : bool
If True, the annotations will only contain the ``required_keys``.
"""
def __init__(self, annotators, required_keys, **kwargs):
def __init__(self, annotators, required_keys, only_required_keys=False,
**kwargs):
super(FailSafe, self).__init__(**kwargs)
self.annotators = []
for annotator in annotators:
......@@ -29,6 +32,7 @@ class FailSafe(Annotator):
annotator = load_resource(annotator, 'annotator')
self.annotators.append(annotator)
self.required_keys = list(required_keys)
self.only_required_keys = only_required_keys
def annotate(self, sample, **kwargs):
if 'annotations' not in kwargs or kwargs['annotations'] is None:
......@@ -51,4 +55,8 @@ class FailSafe(Annotator):
else: # this else is for the for loop
# we don't want to return half of the annotations
kwargs['annotations'] = None
if self.only_required_keys:
for key in list(kwargs['annotations'].keys()):
if key not in self.required_keys:
del kwargs['annotations'][key]
return kwargs['annotations']
......@@ -3,6 +3,7 @@ import os
import shutil
from click.testing import CliRunner
from bob.bio.base.script.annotate import annotate
from bob.bio.base.annotator import Callable, FailSafe
from bob.db.base import read_annotation_file
......@@ -31,3 +32,18 @@ def test_annotate():
assert annot['bottomright'] == [112, 92]
finally:
shutil.rmtree(tmp_dir)
def dummy_extra_key_annotator(data, **kwargs):
return {'leye': 0, 'reye': 0, 'topleft': 0}
def test_failsafe():
annotator = FailSafe([Callable(dummy_extra_key_annotator)],
['leye', 'reye'])
assert all(key in annotator(1) for key in ['leye', 'reye', 'topleft'])
annotator = FailSafe([Callable(dummy_extra_key_annotator)],
['leye', 'reye'], True)
assert all(key in annotator(1) for key in ['leye', 'reye'])
assert all(key not in annotator(1) for key in ['topleft'])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment