From affe8037f58704628aa894624fb52b2cf477db93 Mon Sep 17 00:00:00 2001
From: Olegs NIKISINS <onikisins@italix03.idiap.ch>
Date: Fri, 27 Apr 2018 10:14:03 +0200
Subject: [PATCH] Updated BATL HLDI, fixed distribution of FunnyEyes, added
 option merging train and dev sets

---
 bob/pad/face/database/batl.py | 132 +++++++++++++++++++++++++++++++---
 1 file changed, 124 insertions(+), 8 deletions(-)

diff --git a/bob/pad/face/database/batl.py b/bob/pad/face/database/batl.py
index a898ac5b..7facd4c3 100644
--- a/bob/pad/face/database/batl.py
+++ b/bob/pad/face/database/batl.py
@@ -164,6 +164,9 @@ class BatlPadDatabase(PadDatabase):
             "nowig-depth-5" - nowig protocol, depth data only,
             use 5 first frames.
             "nowig-color" - nowig protocol, depth data only, use all frames.
+            "nowig-infrared-50-join_train_dev" - nowig protocol,
+            infrared data only, use 50 frames, join train and dev sets forming
+            a single large training set.
             See the ``parse_protocol`` method of this class.
 
         ``original_directory`` : str
@@ -246,11 +249,28 @@ class BatlPadDatabase(PadDatabase):
 
         ``max_frames`` : int
             The number of frames to be loaded.
+
+        ``extra`` : str
+            An extra string which is handled in ``self.objects()`` method.
+            Extra strings which are currently handled are defined in
+            ``possible_extras`` of this function.
+            For example, if ``extra="join_train_dev"``, the train and dev
+            sets will be joined in ``self.objects()``,
+            forming a single training set.
         """
 
+        possible_extras = ['join_train_dev']
+
         components = protocol.split("-")
 
-        components = components + [None, None]
+        extra = [item for item in possible_extras if item in components]
+
+        extra = extra[0] if extra else None
+
+        if extra is not None:
+            components.remove(extra)
+
+        components += [None, None]
 
         components = components[0:3]
 
@@ -260,7 +280,82 @@ class BatlPadDatabase(PadDatabase):
 
             max_frames = int(max_frames)
 
-        return protocol, stream_type, max_frames
+        return protocol, stream_type, max_frames, extra
+
+    def _fix_funny_eyes_in_objects(self, protocol, groups, purposes, **kwargs):
+        """
+        This function redistributes FunnyEyes PAs across 'train', 'dev' and
+        'eval' sets in the following way.
+
+        Original (low-level DB) distribution is as follows:
+        'train' = 0
+        'dev' = 27
+        'eval' = 8
+
+        After this function is applied the distribution is:
+        'train' = 19
+        'dev' = 8
+        'eval' = 8
+
+        **Parameters:**
+
+        ``protocol`` : str
+            The protocol for which the clients should be retrieved.
+
+        ``groups`` : :py:class:`str`
+            OR a list of strings; ``None`` selects ``self.low_level_group_names``.
+            Low-level names are expected here, i.e. one or more elements
+            of ('train', 'validation', 'test').
+
+        ``purposes`` : :obj:`str` or [:obj:`str`]
+            The purposes for which File objects should be retrieved.
+            Usually it is either 'real' or 'attack'.
+
+        ``kwargs`` : dict
+            Extra keyword arguments, forwarded to ``self.db.objects()``.
+
+        **Returns:**
+
+        ``files`` : [VideoFile]
+            A list of VideoFile objects defined in BATL Low Level Database
+            Interface, in the order train, dev, eval.
+        """
+
+        if groups is None:
+            groups = self.low_level_group_names
+
+        # files whose paths end with these suffixes relate to FunnyEyes
+        funny_eyes = ("_1_01", "_1_04", "_1_05", "_1_06", "_1_07")
+
+        def query(group):
+            # one low-level query; forwards the caller's extra keyword arguments
+            return self.db.objects(protocol=protocol, groups=group,
+                                   purposes=purposes, **kwargs)
+
+        files_train = []
+        files_dev = []
+        files_eval = []
+
+        # ``in`` matches both a list of names and a single name given as a string
+        if 'train' in groups:
+            files_train = query('train')
+            # append the first 19 FunnyEyes files from "dev" to the "train" set
+            moved = [f for f in query('validation')
+                     if f.path.endswith(funny_eyes)][0:19]
+            files_train = files_train + moved
+
+        if 'validation' in groups:
+            all_dev = query('validation')
+            # keep the last 8 FunnyEyes files plus every non-FunnyEyes file
+            kept = [f for f in all_dev if f.path.endswith(funny_eyes)][-8:]
+            others = [f for f in all_dev if not f.path.endswith(funny_eyes)]
+            files_dev = kept + others
+
+        if 'test' in groups:
+            files_eval = query('test')  # this group remains unchanged
+
+        files = files_train + files_dev + files_eval
+        return files
 
     def objects(self,
                 protocol=None,
@@ -301,17 +396,38 @@ class BatlPadDatabase(PadDatabase):
         if purposes is None:
             purposes = ['real', 'attack']
 
-        protocol, stream_type, max_frames = self.parse_protocol(protocol)
+        protocol, stream_type, max_frames, extra = self.parse_protocol(protocol)
 
         # Convert group names to low-level group names here.
         groups = self.convert_names_to_lowlevel(
             groups, self.low_level_group_names, self.high_level_group_names)
-        # Since this database was designed for PAD experiments, nothing special
-        # needs to be done here.
 
-        files = self.db.objects(protocol=protocol,
-                                groups=groups,
-                                purposes=purposes, **kwargs)
+        if groups is not None and not isinstance(groups, list):  # normalise to a list; list('train') would split the name into characters
+            groups = [groups] if isinstance(groups, str) else list(groups)
+
+        if extra is not None and "join_train_dev" in extra:
+
+            if groups == ['train']: # join "train" and "dev" sets
+                files = self.db.objects(protocol=protocol,
+                                        groups=['train', 'validation'],
+                                        purposes=purposes, **kwargs)
+
+            # return ALL data if "train" and "some other" set/sets are requested
+            elif len(groups)>=2 and 'train' in groups:
+                files = self.db.objects(protocol=protocol,
+                                        groups=self.low_level_group_names,
+                                        purposes=purposes, **kwargs)
+
+            # addresses the cases when groups=['validation'] or ['test'] or ['validation', 'test']:
+            else:
+                files = self.db.objects(protocol=protocol,
+                                        groups=['test'],
+                                        purposes=purposes, **kwargs)
+
+        else:
+            files = self._fix_funny_eyes_in_objects(protocol=protocol,
+                                                    groups=groups,
+                                                    purposes=purposes, **kwargs)
 
         files = [BatlPadFile(f, stream_type, max_frames) for f in files]
         return files
-- 
GitLab