diff --git a/bob/pad/face/database/batl.py b/bob/pad/face/database/batl.py index a69428a8f7b1106008ff8939db7f50d81dadcaf1..91ecfcbeae8a47d3ea658b0c451c4798a1ce12d7 100644 --- a/bob/pad/face/database/batl.py +++ b/bob/pad/face/database/batl.py @@ -10,14 +10,14 @@ from bob.db.batl.batl_config import BATL_CONFIG import pkg_resources from batl.utils.data import load_data_config -class BATLPadFile(PadFile): +class BatlPadFile(PadFile): """ A high level implementation of the File class for the BATL database. """ def __init__(self, f, - stream_types, # a list of streams to be loaded + stream_type, # a list of streams to be loaded max_frames, reference_stream_type="color", warp_to_reference=True, @@ -39,13 +39,13 @@ class BATLPadFile(PadFile): else: attack_type = None - super(BATLPadFile, self).__init__( + super(BatlPadFile, self).__init__( client_id=f.client_id, path=f.path, attack_type=attack_type, file_id=f.id) - self.stream_types = stream_types + self.stream_type = stream_type self.max_frames = max_frames self.reference_stream_type = reference_stream_type # "color" self.data_format_config = load_data_config(pkg_resources.resource_filename('batl.utils', 'config/idiap_hdf5_data_config.json')) @@ -58,7 +58,7 @@ class BATLPadFile(PadFile): data = f.load(self, directory=directory, extension=extension, - stream_types=self.stream_types, # TODO: this parameter is currently missing in bob.db.batl, add it there + stream_type=self.stream_type, # TODO: this parameter is currently missing in bob.db.batl, add it there reference_stream_type=self.reference_stream_type, data_format_config=self.data_format_config, warp_to_reference=self.warp_to_reference, @@ -72,7 +72,7 @@ class BATLPadFile(PadFile): return data -class BATLPadDatabase(PadDatabase): +class BatlPadDatabase(PadDatabase): """ A high level implementation of the Database class for the BATL database. @@ -91,7 +91,13 @@ class BATLPadDatabase(PadDatabase): protocol : str or None The name of the protocol that defines the default experimental - setup for this database. + setup for this database. Also a "complex" protocols can be + parsed. + For example: + "grandtest-color-5" - grandtest protocol, color data only, use 5 first frames. + "grandtest-depth-5" - grandtest protocol, depth data only, use 5 first frames. + "grandtest-color" - grandtest protocol, depth data only, use all frames. + See the ``parse_protocol`` method of this class. original_directory : str The directory where the original data of the database are stored. @@ -118,7 +124,7 @@ class BATLPadDatabase(PadDatabase): 'eval') # names are expected to be like that in objects() function # Always use super to call parent class methods. - super(BATLPadDatabase, self).__init__( + super(BatlPadDatabase, self).__init__( name='batl', protocol=protocol, original_directory=original_directory, @@ -133,6 +139,45 @@ class BATLPadDatabase(PadDatabase): def original_directory(self, value): self.db.original_directory = value + def parse_protocol(self, protocol): + """ + Parse the protocol name, which is give as a string. + An example of protocols it can parse: + "grandtest-color-5" - grandtest protocol, color data only, use 5 first frames. + "grandtest-depth-5" - grandtest protocol, depth data only, use 5 first frames. + "grandtest-color" - grandtest protocol, depth data only, use all frames. + + **Parameters:** + + ``protocol`` : str + Protocol name to be parsed. Example: "grandtest-depth-5" . + + **Returns:** + + ``protocol`` : str + The name of the protocol as defined in the low level db interface. + + ``stream_types`` : str + The name of the channel/stream_type to be loaded. + + ``max_frames`` : int + The number of frames to be loaded. + """ + + components = protocol.split("-") + + components = components + [None, None] + + components = components[0:3] + + protocol, stream_types, max_frames = components + + if max_frames is not None: + + max_frames = int(max_frames) + + return protocol, stream_types, max_frames + def objects(self, protocol=None, group=None, @@ -164,14 +209,18 @@ class BATLPadDatabase(PadDatabase): A list of BATLPadFile objects. """ + protocol, stream_types, max_frames = self.parse_protocol(protocol) + # Convert group names to low-level group names here. groups = self.convert_names_to_lowlevel( groups, self.low_level_group_names, self.high_level_group_names) # Since this database was designed for PAD experiments, nothing special # needs to be done here. files = self.db.objects(protocol=protocol, groups=groups, purposes=purposes **kwargs) - files = [BATLPadFile(f) for f in files] + files = [BatlPadFile(f, stream_type, max_frames) for f in files] return files def annotations(self, f): pass + +