diff --git a/bob/pad/face/database/batl.py b/bob/pad/face/database/batl.py index d88a9f0869bc07c6aaed5ffbbeacf9069469bed2..284e0bf5da59e20e416168b1fbc0f0948edbc9ef 100644 --- a/bob/pad/face/database/batl.py +++ b/bob/pad/face/database/batl.py @@ -150,6 +150,7 @@ class BatlPadDatabase(PadDatabase): original_extension='.h5', annotations_temp_dir="", landmark_detect_method="mtcnn", + exlude_attacks_list=None, **kwargs): """ **Parameters:** @@ -186,6 +187,12 @@ class BatlPadDatabase(PadDatabase): landmarks. Possible options: "dlib" or "mtcnn". Default: ``"mtcnn"``. + ``exlude_attacks_list`` : [str] + A list of strings defining which attacks should be excluded from + the training set. This shoould be handled in ``objects()`` method. + Currently handled attacks: "makeup". + Default: ``None``. + ``kwargs`` : dict The arguments of the :py:class:`bob.bio.base.database.BioDatabase` base class constructor. @@ -217,6 +224,7 @@ class BatlPadDatabase(PadDatabase): self.original_extension = original_extension self.annotations_temp_dir = annotations_temp_dir self.landmark_detect_method = landmark_detect_method + self.exlude_attacks_list = exlude_attacks_list @property def original_directory(self): @@ -378,6 +386,11 @@ class BatlPadDatabase(PadDatabase): The protocol is dependent on your database. If you do not have protocols defined, just ignore this field. + ``groups`` : :py:class:`str` + OR a list of strings. + The groups of which the clients should be returned. + Usually, groups are one or more elements of ('train', 'dev', 'eval') + ``purposes`` : :obj:`str` or [:obj:`str`] The purposes for which File objects should be retrieved. Usually it is either 'real' or 'attack'. @@ -433,7 +446,14 @@ class BatlPadDatabase(PadDatabase): groups=groups, purposes=purposes, **kwargs) + if groups == 'train' or 'train' in groups and len(groups) == 1: + # exclude "makeup" case + if self.exlude_attacks_list is not None and "makeup" in self.exlude_attacks_list: + + files = [f for f in files if os.path.split(f.path)[-1].split("_")[-2:-1][0] != "5"] + files = [BatlPadFile(f, stream_type, max_frames) for f in files] + return files def annotations(self, f):