Skip to content
Snippets Groups Projects
Commit 93b33fe3 authored by Olegs NIKISINS's avatar Olegs NIKISINS
Browse files

Some bug fixes in BATL HLDI, objects(), etc.

parent 2578a83d
No related branches found
No related tags found
1 merge request!59Added HLDI for the BATL database, added optional data normalization in FaceCropAlign
Pipeline #
...@@ -14,6 +14,8 @@ from bob.pad.face.preprocessor.FaceCropAlign import detect_face_landmarks_in_ima ...@@ -14,6 +14,8 @@ from bob.pad.face.preprocessor.FaceCropAlign import detect_face_landmarks_in_ima
import json import json
import os
class BatlPadFile(PadFile): class BatlPadFile(PadFile):
""" """
A high level implementation of the File class for the BATL A high level implementation of the File class for the BATL
...@@ -38,7 +40,7 @@ class BatlPadFile(PadFile): ...@@ -38,7 +40,7 @@ class BatlPadFile(PadFile):
self.f = f self.f = f
if f.is_attack(): if f.is_attack():
attack = batl_config[f.type_id] attack = BATL_CONFIG[f.type_id]
attack_type = '{} : {}'.format(attack['name'], attack['pai'][f.pai_id]) attack_type = '{} : {}'.format(attack['name'], attack['pai'][f.pai_id])
else: else:
attack_type = None attack_type = None
...@@ -60,7 +62,7 @@ class BatlPadFile(PadFile): ...@@ -60,7 +62,7 @@ class BatlPadFile(PadFile):
def load(self, directory=None, extension='.hdf5', frame_selector=FrameSelector(selection_style='all')): def load(self, directory=None, extension='.hdf5', frame_selector=FrameSelector(selection_style='all')):
data = f.load(self, directory=directory, data = self.f.load(self, directory=directory,
extension=extension, extension=extension,
modality=self.stream_type, # TODO: this parameter is currently missing in bob.db.batl, add it there modality=self.stream_type, # TODO: this parameter is currently missing in bob.db.batl, add it there
reference_stream_type=self.reference_stream_type, reference_stream_type=self.reference_stream_type,
...@@ -169,7 +171,7 @@ class BatlPadDatabase(PadDatabase): ...@@ -169,7 +171,7 @@ class BatlPadDatabase(PadDatabase):
``protocol`` : str ``protocol`` : str
The name of the protocol as defined in the low level db interface. The name of the protocol as defined in the low level db interface.
``stream_types`` : str ``stream_type`` : str
The name of the channel/stream_type to be loaded. The name of the channel/stream_type to be loaded.
``max_frames`` : int ``max_frames`` : int
...@@ -182,19 +184,19 @@ class BatlPadDatabase(PadDatabase): ...@@ -182,19 +184,19 @@ class BatlPadDatabase(PadDatabase):
components = components[0:3] components = components[0:3]
protocol, stream_types, max_frames = components protocol, stream_type, max_frames = components
if max_frames is not None: if max_frames is not None:
max_frames = int(max_frames) max_frames = int(max_frames)
return protocol, stream_types, max_frames return protocol, stream_type, max_frames
def objects(self, def objects(self,
protocol=None, protocol=None,
groups=None, groups=None,
purposes=None, purposes=None,
sessions=None, model_ids=None,
**kwargs): **kwargs):
""" """
This function returns lists of BATLPadFile objects, which fulfill the This function returns lists of BATLPadFile objects, which fulfill the
...@@ -221,17 +223,17 @@ class BatlPadDatabase(PadDatabase): ...@@ -221,17 +223,17 @@ class BatlPadDatabase(PadDatabase):
A list of BATLPadFile objects. A list of BATLPadFile objects.
""" """
protocol, stream_types, max_frames = self.parse_protocol(protocol) protocol, stream_type, max_frames = self.parse_protocol(protocol)
# Convert group names to low-level group names here. # Convert group names to low-level group names here.
groups = self.convert_names_to_lowlevel( groups = self.convert_names_to_lowlevel(
groups, self.low_level_group_names, self.high_level_group_names) groups, self.low_level_group_names, self.high_level_group_names)
# Since this database was designed for PAD experiments, nothing special # Since this database was designed for PAD experiments, nothing special
# needs to be done here. # needs to be done here.
files = self.db.objects(protocol=protocol, groups=groups, purposes=purposes **kwargs) # files = self.db.objects(protocol=protocol, groups=groups, purposes=purposes **kwargs)
files = self.db.objects(protocol=protocol, purposes=groups, **kwargs)
# files = self.db.objects(protocol=protocol, purposes=groups, **kwargs)
# #
# if purposes == ["real", "attack"]: # if purposes == ["real", "attack"]:
# #
...@@ -294,3 +296,8 @@ class BatlPadDatabase(PadDatabase): ...@@ -294,3 +296,8 @@ class BatlPadDatabase(PadDatabase):
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment