From affe8037f58704628aa894624fb52b2cf477db93 Mon Sep 17 00:00:00 2001 From: Olegs NIKISINS <onikisins@italix03.idiap.ch> Date: Fri, 27 Apr 2018 10:14:03 +0200 Subject: [PATCH] Updated BATL HLDI, fixed distribution of FunnyEyes, added option merging train and dev sets --- bob/pad/face/database/batl.py | 132 +++++++++++++++++++++++++++++++--- 1 file changed, 124 insertions(+), 8 deletions(-) diff --git a/bob/pad/face/database/batl.py b/bob/pad/face/database/batl.py index a898ac5b..7facd4c3 100644 --- a/bob/pad/face/database/batl.py +++ b/bob/pad/face/database/batl.py @@ -164,6 +164,9 @@ class BatlPadDatabase(PadDatabase): "nowig-depth-5" - nowig protocol, depth data only, use 5 first frames. "nowig-color" - nowig protocol, depth data only, use all frames. + "nowig-infrared-50-join_train_dev" - nowig protocol, + infrared data only, use 50 frames, join train and dev sets forming + a single large training set. See the ``parse_protocol`` method of this class. ``original_directory`` : str @@ -246,11 +249,28 @@ class BatlPadDatabase(PadDatabase): ``max_frames`` : int The number of frames to be loaded. + + ``extra`` : str + An extra string which is handled in ``self.objects()`` method. + Extra strings which are currently handled are defined in + ``possible_extras`` of this function. + For example, if ``extra="join_train_dev"``, the train and dev + sets will be joined in ``self.objects()``, + forming a single training set. """ + possible_extras = ['join_train_dev'] + components = protocol.split("-") - components = components + [None, None] + extra = [item for item in possible_extras if item in components] + + extra = extra[0] if extra else None + + if extra is not None: + components.remove(extra) + + components += [None, None] components = components[0:3] @@ -260,7 +280,82 @@ class BatlPadDatabase(PadDatabase): max_frames = int(max_frames) - return protocol, stream_type, max_frames + return protocol, stream_type, max_frames, extra + + def _fix_funny_eyes_in_objects(self, protocol, groups, purposes): + """ + This function redistributes FunnyEyes PAs accross 'train', 'dev' and + 'eval' sets in the following way. + + Original (low-level DB) distribution is as follows: + 'train' = 0 + 'dev' = 27 + 'eval' = 8 + + After this function is applied the distribution is: + 'train' = 19 + 'dev' = 8 + 'eval' = 8 + + **Parameters:** + + ``protocol`` : str + The protocol for which the clients should be retrieved. + + ``groups`` : :py:class:`str` + OR a list of strings. + The groups of which the clients should be returned. + Usually, groups are one or more elements of ('train', 'dev', 'eval') + + ``purposes`` : :obj:`str` or [:obj:`str`] + The purposes for which File objects should be retrieved. + Usually it is either 'real' or 'attack'. + + **Returns:** + + ``files`` : [VideoFile] + A list of VideoFile objects defined in BATL Low Level Database + Interface. + """ + + if groups is None: + groups = self.low_level_group_names + + files_train = [] + files_dev = [] + files_eval = [] + + if groups == 'train' or 'train' in groups: + + files_train = self.db.objects(protocol=protocol, groups='train', purposes=purposes) + + files_to_append = self.db.objects(protocol=protocol, groups='validation', purposes=purposes) + + exclude = ["_1_01", "_1_04", "_1_05", "_1_06", "_1_07"] # files ending with these paths relate to FunnyEyes + + files_to_append = [f for f in files_to_append if f.path[-5:] in exclude][0:19] # append 19 files from "dev" to "train" set + + files_train = files_train + files_to_append + + if groups == 'validation' or 'validation' in groups: + + files_dev = self.db.objects(protocol=protocol, groups='validation', purposes=purposes) + + exclude = ["_1_01", "_1_04", "_1_05", "_1_06", "_1_07"] # files ending with these paths relate to FunnyEyes + + files_to_append_1 = [f for f in files_dev if f.path[-5:] in exclude][-8:] # 8 "dev" files containing FunnyEyes + + files_to_append_2 = [f for f in files_dev if f.path[-5:] not in exclude] # "dev" set without FunnyEyes + + files_dev = files_to_append_1 + files_to_append_2 + + if groups == 'test' or 'test' in groups: + + files_eval = self.db.objects(protocol=protocol, groups='test', purposes=purposes) # this group remain unchanged + + files = files_train + files_dev + files_eval + + return files def objects(self, protocol=None, @@ -301,17 +396,38 @@ class BatlPadDatabase(PadDatabase): if purposes is None: purposes = ['real', 'attack'] - protocol, stream_type, max_frames = self.parse_protocol(protocol) + protocol, stream_type, max_frames, extra = self.parse_protocol(protocol) # Convert group names to low-level group names here. groups = self.convert_names_to_lowlevel( groups, self.low_level_group_names, self.high_level_group_names) - # Since this database was designed for PAD experiments, nothing special - # needs to be done here. - files = self.db.objects(protocol=protocol, - groups=groups, - purposes=purposes, **kwargs) + if not isinstance(groups, list) and groups is not None: # if a single group is given make it a list + groups = list(groups) + + if extra is not None and "join_train_dev" in extra: + + if groups == ['train']: # join "train" and "dev" sets + files = self.db.objects(protocol=protocol, + groups=['train', 'validation'], + purposes=purposes, **kwargs) + + # return ALL data if "train" and "some other" set/sets are requested + elif len(groups)>=2 and 'train' in groups: + files = self.db.objects(protocol=protocol, + groups=self.low_level_group_names, + purposes=purposes, **kwargs) + + # addresses the cases when groups=['validation'] or ['test'] or ['validation', 'test']: + else: + files = self.db.objects(protocol=protocol, + groups=['test'], + purposes=purposes, **kwargs) + + else: + files = self._fix_funny_eyes_in_objects(protocol=protocol, + groups=groups, + purposes=purposes, **kwargs) files = [BatlPadFile(f, stream_type, max_frames) for f in files] return files -- GitLab