Skip to content
Snippets Groups Projects
Commit 92c95a47 authored by Olegs NIKISINS's avatar Olegs NIKISINS
Browse files

Merge branch 'batl_hldi' into 'master'

Updated the HLDI of BATL DB, added FunnyEyes fix, and protocol joining test and dev sets

See merge request !65
parents dcaf6888 aa938823
Branches
Tags
1 merge request!65Updated the HLDI of BATL DB, added FunnyEyes fix, and protocol joining test and dev sets
Pipeline #
......@@ -164,6 +164,9 @@ class BatlPadDatabase(PadDatabase):
"nowig-depth-5" - nowig protocol, depth data only,
use 5 first frames.
"nowig-color" - nowig protocol, depth data only, use all frames.
"nowig-infrared-50-join_train_dev" - nowig protocol,
infrared data only, use 50 frames, join train and dev sets forming
a single large training set.
See the ``parse_protocol`` method of this class.
``original_directory`` : str
......@@ -246,11 +249,28 @@ class BatlPadDatabase(PadDatabase):
``max_frames`` : int
The number of frames to be loaded.
``extra`` : str
An extra string which is handled in ``self.objects()`` method.
Extra strings which are currently handled are defined in
``possible_extras`` of this function.
For example, if ``extra="join_train_dev"``, the train and dev
sets will be joined in ``self.objects()``,
forming a single training set.
"""
possible_extras = ['join_train_dev']
components = protocol.split("-")
components = components + [None, None]
extra = [item for item in possible_extras if item in components]
extra = extra[0] if extra else None
if extra is not None:
components.remove(extra)
components += [None, None]
components = components[0:3]
......@@ -260,7 +280,86 @@ class BatlPadDatabase(PadDatabase):
max_frames = int(max_frames)
return protocol, stream_type, max_frames
return protocol, stream_type, max_frames, extra
def _fix_funny_eyes_in_objects(self, protocol, groups, purposes):
"""
This function redistributes FunnyEyes PAs accross 'train', 'dev' and
'eval' sets in the following way.
Original (low-level DB) distribution is as follows:
'train' = N1
'dev' = N2
'eval' = N3
After this function is applied the distribution is:
'train' = N1 + 1/2*N2
'dev' = N2 - 1/2*N2
'eval' = N3
**Parameters:**
``protocol`` : str
The protocol for which the clients should be retrieved.
``groups`` : :py:class:`str`
OR a list of strings.
The groups of which the clients should be returned.
Usually, groups are one or more elements of ('train', 'dev', 'eval')
``purposes`` : :obj:`str` or [:obj:`str`]
The purposes for which File objects should be retrieved.
Usually it is either 'real' or 'attack'.
**Returns:**
``files`` : [VideoFile]
A list of VideoFile objects defined in BATL Low Level Database
Interface.
"""
if groups is None:
groups = self.low_level_group_names
files_train = []
files_dev = []
files_eval = []
if groups == 'train' or 'train' in groups:
files_train = self.db.objects(protocol=protocol, groups='train', purposes=purposes)
files_to_append = self.db.objects(protocol=protocol, groups='validation', purposes=purposes)
exclude = ["_1_01", "_1_04", "_1_05", "_1_06", "_1_07"] # files ending with these paths relate to FunnyEyes
files_to_append = [f for f in files_to_append if f.path[-5:] in exclude]
files_to_append = files_to_append[0:int(len(files_to_append)/2)] # append HALF of files from "dev" to "train" set
files_train = files_train + files_to_append
if groups == 'validation' or 'validation' in groups:
files_dev = self.db.objects(protocol=protocol, groups='validation', purposes=purposes)
exclude = ["_1_01", "_1_04", "_1_05", "_1_06", "_1_07"] # files ending with these paths relate to FunnyEyes
files_to_append_1 = [f for f in files_dev if f.path[-5:] in exclude] # "dev" files containing FunnyEyes
files_to_append_1 = files_to_append_1[-int(len(files_to_append_1)/2):] # second HALF of "dev" files containing FunnyEyes
files_to_append_2 = [f for f in files_dev if f.path[-5:] not in exclude] # "dev" set without FunnyEyes
files_dev = files_to_append_1 + files_to_append_2
if groups == 'test' or 'test' in groups:
files_eval = self.db.objects(protocol=protocol, groups='test', purposes=purposes) # this group remain unchanged
files = files_train + files_dev + files_eval
return files
def objects(self,
protocol=None,
......@@ -301,17 +400,38 @@ class BatlPadDatabase(PadDatabase):
if purposes is None:
purposes = ['real', 'attack']
protocol, stream_type, max_frames = self.parse_protocol(protocol)
protocol, stream_type, max_frames, extra = self.parse_protocol(protocol)
# Convert group names to low-level group names here.
groups = self.convert_names_to_lowlevel(
groups, self.low_level_group_names, self.high_level_group_names)
# Since this database was designed for PAD experiments, nothing special
# needs to be done here.
files = self.db.objects(protocol=protocol,
groups=groups,
purposes=purposes, **kwargs)
if not isinstance(groups, list) and groups is not None: # if a single group is given make it a list
groups = list(groups)
if extra is not None and "join_train_dev" in extra:
if groups == ['train']: # join "train" and "dev" sets
files = self.db.objects(protocol=protocol,
groups=['train', 'validation'],
purposes=purposes, **kwargs)
# return ALL data if "train" and "some other" set/sets are requested
elif len(groups)>=2 and 'train' in groups:
files = self.db.objects(protocol=protocol,
groups=self.low_level_group_names,
purposes=purposes, **kwargs)
# addresses the cases when groups=['validation'] or ['test'] or ['validation', 'test']:
else:
files = self.db.objects(protocol=protocol,
groups=['test'],
purposes=purposes, **kwargs)
else:
files = self._fix_funny_eyes_in_objects(protocol=protocol,
groups=groups,
purposes=purposes, **kwargs)
files = [BatlPadFile(f, stream_type, max_frames) for f in files]
return files
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment