Skip to content
Snippets Groups Projects
Commit affe8037 authored by Olegs NIKISINS's avatar Olegs NIKISINS
Browse files

Updated BATL HLDI, fixed distribution of FunnyEyes, added option merging train and dev sets

parent dcaf6888
No related branches found
No related tags found
1 merge request!65Updated the HLDI of BATL DB, added FunnyEyes fix, and protocol joining test and dev sets
Pipeline #
......@@ -164,6 +164,9 @@ class BatlPadDatabase(PadDatabase):
"nowig-depth-5" - nowig protocol, depth data only,
use 5 first frames.
"nowig-color" - nowig protocol, depth data only, use all frames.
"nowig-infrared-50-join_train_dev" - nowig protocol,
infrared data only, use 50 frames, join train and dev sets forming
a single large training set.
See the ``parse_protocol`` method of this class.
``original_directory`` : str
......@@ -246,11 +249,28 @@ class BatlPadDatabase(PadDatabase):
``max_frames`` : int
The number of frames to be loaded.
``extra`` : str
An extra string which is handled in ``self.objects()`` method.
Extra strings which are currently handled are defined in
``possible_extras`` of this function.
For example, if ``extra="join_train_dev"``, the train and dev
sets will be joined in ``self.objects()``,
forming a single training set.
"""
possible_extras = ['join_train_dev']
components = protocol.split("-")
components = components + [None, None]
extra = [item for item in possible_extras if item in components]
extra = extra[0] if extra else None
if extra is not None:
components.remove(extra)
components += [None, None]
components = components[0:3]
......@@ -260,7 +280,82 @@ class BatlPadDatabase(PadDatabase):
max_frames = int(max_frames)
return protocol, stream_type, max_frames
return protocol, stream_type, max_frames, extra
def _fix_funny_eyes_in_objects(self, protocol, groups, purposes):
"""
This function redistributes FunnyEyes PAs accross 'train', 'dev' and
'eval' sets in the following way.
Original (low-level DB) distribution is as follows:
'train' = 0
'dev' = 27
'eval' = 8
After this function is applied the distribution is:
'train' = 19
'dev' = 8
'eval' = 8
**Parameters:**
``protocol`` : str
The protocol for which the clients should be retrieved.
``groups`` : :py:class:`str`
OR a list of strings.
The groups of which the clients should be returned.
Usually, groups are one or more elements of ('train', 'dev', 'eval')
``purposes`` : :obj:`str` or [:obj:`str`]
The purposes for which File objects should be retrieved.
Usually it is either 'real' or 'attack'.
**Returns:**
``files`` : [VideoFile]
A list of VideoFile objects defined in BATL Low Level Database
Interface.
"""
if groups is None:
groups = self.low_level_group_names
files_train = []
files_dev = []
files_eval = []
if groups == 'train' or 'train' in groups:
files_train = self.db.objects(protocol=protocol, groups='train', purposes=purposes)
files_to_append = self.db.objects(protocol=protocol, groups='validation', purposes=purposes)
exclude = ["_1_01", "_1_04", "_1_05", "_1_06", "_1_07"] # files ending with these paths relate to FunnyEyes
files_to_append = [f for f in files_to_append if f.path[-5:] in exclude][0:19] # append 19 files from "dev" to "train" set
files_train = files_train + files_to_append
if groups == 'validation' or 'validation' in groups:
files_dev = self.db.objects(protocol=protocol, groups='validation', purposes=purposes)
exclude = ["_1_01", "_1_04", "_1_05", "_1_06", "_1_07"] # files ending with these paths relate to FunnyEyes
files_to_append_1 = [f for f in files_dev if f.path[-5:] in exclude][-8:] # 8 "dev" files containing FunnyEyes
files_to_append_2 = [f for f in files_dev if f.path[-5:] not in exclude] # "dev" set without FunnyEyes
files_dev = files_to_append_1 + files_to_append_2
if groups == 'test' or 'test' in groups:
files_eval = self.db.objects(protocol=protocol, groups='test', purposes=purposes) # this group remain unchanged
files = files_train + files_dev + files_eval
return files
def objects(self,
protocol=None,
......@@ -301,17 +396,38 @@ class BatlPadDatabase(PadDatabase):
if purposes is None:
purposes = ['real', 'attack']
protocol, stream_type, max_frames = self.parse_protocol(protocol)
protocol, stream_type, max_frames, extra = self.parse_protocol(protocol)
# Convert group names to low-level group names here.
groups = self.convert_names_to_lowlevel(
groups, self.low_level_group_names, self.high_level_group_names)
# Since this database was designed for PAD experiments, nothing special
# needs to be done here.
files = self.db.objects(protocol=protocol,
groups=groups,
purposes=purposes, **kwargs)
if not isinstance(groups, list) and groups is not None: # if a single group is given make it a list
groups = list(groups)
if extra is not None and "join_train_dev" in extra:
if groups == ['train']: # join "train" and "dev" sets
files = self.db.objects(protocol=protocol,
groups=['train', 'validation'],
purposes=purposes, **kwargs)
# return ALL data if "train" and "some other" set/sets are requested
elif len(groups)>=2 and 'train' in groups:
files = self.db.objects(protocol=protocol,
groups=self.low_level_group_names,
purposes=purposes, **kwargs)
# addresses the cases when groups=['validation'] or ['test'] or ['validation', 'test']:
else:
files = self.db.objects(protocol=protocol,
groups=['test'],
purposes=purposes, **kwargs)
else:
files = self._fix_funny_eyes_in_objects(protocol=protocol,
groups=groups,
purposes=purposes, **kwargs)
files = [BatlPadFile(f, stream_type, max_frames) for f in files]
return files
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment