Skip to content
Snippets Groups Projects

Added an option to exclude specific types of attacks from train set in BATL DB

Merged Olegs NIKISINS requested to merge batl_db_update into master
1 file
+ 20
0
Compare changes
  • Side-by-side
  • Inline
@@ -150,6 +150,7 @@ class BatlPadDatabase(PadDatabase):
original_extension='.h5',
annotations_temp_dir="",
landmark_detect_method="mtcnn",
exlude_attacks_list=None,
**kwargs):
"""
**Parameters:**
@@ -186,6 +187,12 @@ class BatlPadDatabase(PadDatabase):
landmarks. Possible options: "dlib" or "mtcnn".
Default: ``"mtcnn"``.
``exlude_attacks_list`` : [str]
A list of strings defining which attacks should be excluded from
the training set. This shoould be handled in ``objects()`` method.
Currently handled attacks: "makeup".
Default: ``None``.
``kwargs`` : dict
The arguments of the :py:class:`bob.bio.base.database.BioDatabase`
base class constructor.
@@ -217,6 +224,7 @@ class BatlPadDatabase(PadDatabase):
self.original_extension = original_extension
self.annotations_temp_dir = annotations_temp_dir
self.landmark_detect_method = landmark_detect_method
self.exlude_attacks_list = exlude_attacks_list
@property
def original_directory(self):
@@ -378,6 +386,11 @@ class BatlPadDatabase(PadDatabase):
The protocol is dependent on your database.
If you do not have protocols defined, just ignore this field.
``groups`` : :py:class:`str`
OR a list of strings.
The groups of which the clients should be returned.
Usually, groups are one or more elements of ('train', 'dev', 'eval')
``purposes`` : :obj:`str` or [:obj:`str`]
The purposes for which File objects should be retrieved.
Usually it is either 'real' or 'attack'.
@@ -433,7 +446,14 @@ class BatlPadDatabase(PadDatabase):
groups=groups,
purposes=purposes, **kwargs)
if groups == 'train' or 'train' in groups and len(groups) == 1:
# exclude "makeup" case
if self.exlude_attacks_list is not None and "makeup" in self.exlude_attacks_list:
files = [f for f in files if os.path.split(f.path)[-1].split("_")[-2:-1][0] != "5"]
files = [BatlPadFile(f, stream_type, max_frames) for f in files]
return files
def annotations(self, f):
Loading