Commit 36ce6e95 authored by Vincent POLLET's avatar Vincent POLLET
Browse files

Options to use database annotations function instead of PadFile.annotations

parent 8f5cf008
Pipeline #48875 passed with stage
in 7 minutes and 28 seconds
......@@ -9,13 +9,15 @@ from .abstract_classes import Database
logger = logging.getLogger(__name__)
def _padfile_to_delayed_sample(padfile, database):
def _padfile_to_delayed_sample(padfile, database, use_db_annotations):
return DelayedSample(
load=padfile.load,
subject=str(padfile.client_id),
attack_type=padfile.attack_type,
key=padfile.path,
delayed_attributes=dict(annotations=lambda : padfile.annotations),
delayed_attributes=dict(
annotations=lambda: database.annotations(padfile) if use_db_annotations else padfile.annotations
),
is_bonafide=padfile.attack_type is None,
)
......@@ -34,16 +36,17 @@ class DatabaseConnector(Database):
"""
def __init__(
self, database, annotation_type="eyes-center", fixed_positions=None, **kwargs
self, database, annotation_type="eyes-center", use_db_annotations=False, fixed_positions=None, **kwargs
):
super().__init__(**kwargs)
self.database = database
self.annotation_type = annotation_type
self.use_db_annotations = use_db_annotations
self.fixed_positions = fixed_positions
def fit_samples(self):
objects = self.database.training_files(flat=True)
return [_padfile_to_delayed_sample(k, self.database) for k in objects]
return [_padfile_to_delayed_sample(k, self.database, self.use_db_annotations) for k in objects]
def predict_samples(self, group="dev"):
objects = self.database.all_files(groups=group, flat=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment