Skip to content
Snippets Groups Projects

WIP: Options to use database annotations function instead of PadFile.annotations

Closed Vincent POLLET requested to merge use_legacy_db_annotations into master
1 file
+ 7
4
Compare changes
  • Side-by-side
  • Inline
@@ -9,13 +9,15 @@ from .abstract_classes import Database
@@ -9,13 +9,15 @@ from .abstract_classes import Database
logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
def _padfile_to_delayed_sample(padfile, database):
def _padfile_to_delayed_sample(padfile, database, use_db_annotations):
return DelayedSample(
return DelayedSample(
load=padfile.load,
load=padfile.load,
subject=str(padfile.client_id),
subject=str(padfile.client_id),
attack_type=padfile.attack_type,
attack_type=padfile.attack_type,
key=padfile.path,
key=padfile.path,
delayed_attributes=dict(annotations=lambda : padfile.annotations),
delayed_attributes=dict(
 
annotations=lambda: database.annotations(padfile) if use_db_annotations else padfile.annotations
 
),
is_bonafide=padfile.attack_type is None,
is_bonafide=padfile.attack_type is None,
)
)
@@ -34,16 +36,17 @@ class DatabaseConnector(Database):
@@ -34,16 +36,17 @@ class DatabaseConnector(Database):
"""
"""
def __init__(
def __init__(
self, database, annotation_type="eyes-center", fixed_positions=None, **kwargs
self, database, annotation_type="eyes-center", use_db_annotations=False, fixed_positions=None, **kwargs
):
):
super().__init__(**kwargs)
super().__init__(**kwargs)
self.database = database
self.database = database
self.annotation_type = annotation_type
self.annotation_type = annotation_type
 
self.use_db_annotations = use_db_annotations
self.fixed_positions = fixed_positions
self.fixed_positions = fixed_positions
def fit_samples(self):
def fit_samples(self):
objects = self.database.training_files(flat=True)
objects = self.database.training_files(flat=True)
return [_padfile_to_delayed_sample(k, self.database) for k in objects]
return [_padfile_to_delayed_sample(k, self.database, self.use_db_annotations) for k in objects]
def predict_samples(self, group="dev"):
def predict_samples(self, group="dev"):
objects = self.database.all_files(groups=group, flat=True)
objects = self.database.all_files(groups=group, flat=True)
Loading