diff --git a/bob/bio/base/annotator/FailSafe.py b/bob/bio/base/annotator/FailSafe.py index 091e55300e58a467fe5b06170d66f71029ae4746..bfad2513ebcb00bca42c5ef2e61c3326616af2f8 100644 --- a/bob/bio/base/annotator/FailSafe.py +++ b/bob/bio/base/annotator/FailSafe.py @@ -19,9 +19,12 @@ class FailSafe(Annotator): required_keys : list A list of keys that should be available in annotations to stop trying different annotators. + only_required_keys : bool + If True, the annotations will only contain the ``required_keys``. """ - def __init__(self, annotators, required_keys, **kwargs): + def __init__(self, annotators, required_keys, only_required_keys=False, + **kwargs): super(FailSafe, self).__init__(**kwargs) self.annotators = [] for annotator in annotators: @@ -29,6 +32,7 @@ class FailSafe(Annotator): annotator = load_resource(annotator, 'annotator') self.annotators.append(annotator) self.required_keys = list(required_keys) + self.only_required_keys = only_required_keys def annotate(self, sample, **kwargs): if 'annotations' not in kwargs or kwargs['annotations'] is None: @@ -51,4 +55,8 @@ class FailSafe(Annotator): else: # this else is for the for loop # we don't want to return half of the annotations kwargs['annotations'] = None + if self.only_required_keys: + for key in list(kwargs['annotations'].keys()): + if key not in self.required_keys: + del kwargs['annotations'][key] return kwargs['annotations'] diff --git a/bob/bio/base/test/test_annotators.py b/bob/bio/base/test/test_annotators.py index b8735b1e748e5c8d637fc429db3bfe3b1b551714..072db6f092a807ee26e3ab2a8ddba036c0713fc3 100644 --- a/bob/bio/base/test/test_annotators.py +++ b/bob/bio/base/test/test_annotators.py @@ -3,6 +3,7 @@ import os import shutil from click.testing import CliRunner from bob.bio.base.script.annotate import annotate +from bob.bio.base.annotator import Callable, FailSafe from bob.db.base import read_annotation_file @@ -31,3 +32,18 @@ def test_annotate(): assert annot['bottomright'] == [112, 92] finally: shutil.rmtree(tmp_dir) + + +def dummy_extra_key_annotator(data, **kwargs): + return {'leye': 0, 'reye': 0, 'topleft': 0} + + +def test_failsafe(): + annotator = FailSafe([Callable(dummy_extra_key_annotator)], + ['leye', 'reye']) + assert all(key in annotator(1) for key in ['leye', 'reye', 'topleft']) + + annotator = FailSafe([Callable(dummy_extra_key_annotator)], + ['leye', 'reye'], True) + assert all(key in annotator(1) for key in ['leye', 'reye']) + assert all(key not in annotator(1) for key in ['topleft'])