From 219546128d59fe5b36c12a645a9e46515f982709 Mon Sep 17 00:00:00 2001
From: Manuel Gunther <siebenkopf@googlemail.com>
Date: Mon, 23 Jan 2017 16:16:43 -0700
Subject: [PATCH] Implemented better behavior when protocol is present; added
 possibility to only have 'dev' group

---
 bob/bio/base/database/filelist/query.py       | 202 ++++++++----------
 .../data/example_filelist2/dev/for_models.lst |  12 ++
 .../data/example_filelist2/dev/for_scores.lst |  14 ++
 bob/bio/base/test/test_filelist.py            |  12 +-
 4 files changed, 125 insertions(+), 115 deletions(-)
 create mode 100644 bob/bio/base/test/data/example_filelist2/dev/for_models.lst
 create mode 100644 bob/bio/base/test/data/example_filelist2/dev/for_scores.lst

diff --git a/bob/bio/base/database/filelist/query.py b/bob/bio/base/database/filelist/query.py
index fe034c02..09ae85f8 100644
--- a/bob/bio/base/database/filelist/query.py
+++ b/bob/bio/base/database/filelist/query.py
@@ -182,81 +182,45 @@ class FileListBioDatabase(ZTBioDatabase):
         # Z-Norm files       format:   filename client_id
         self.m_znorm_filename = znorm_filename if znorm_filename is not None else 'for_znorm.lst'
 
-        # decide, which scoring type we have:
-        if probes_filename is not None and scores_filename is None:
-            self.m_use_dense_probes = True
-        elif probes_filename is None and scores_filename is not None:
-            self.m_use_dense_probes = False
-        elif use_dense_probe_file_list is not None:
-            self.m_use_dense_probes = use_dense_probe_file_list
-        # Then direct path to a given protocol
-        elif os.path.isdir(os.path.join(self.get_base_directory(), self.m_dev_subdir)) or os.path.isfile(
-                os.path.join(self.get_base_directory(), self.m_world_filename)):
-            if os.path.exists(self.get_list_file('dev', 'for_probes')) and not os.path.exists(
-                    self.get_list_file('dev', 'for_scores')):
-                self.m_use_dense_probes = True
-            elif not os.path.exists(self.get_list_file('dev', 'for_probes')) and os.path.exists(
-                    self.get_list_file('dev', 'for_scores')):
-                self.m_use_dense_probes = False
-            else:
-                raise ValueError("Unable to determine, which way of probing should be used. Please specify.")
-        # Then path to a directory that contains several subdirectories (one for each protocol)
-        else:
-            # Look at subdirectories for each protocol
-            protocols = [p for p in os.listdir(self.get_base_directory()) if
-                         os.path.isdir(os.path.join(self.get_base_directory(), p))]
-            if len(protocols) == 0:
-                raise ValueError(
-                    "Unable to determine, which way of probing should be used (no protocol directories found). Please specify.")
-            list_use_dense_probes = []
-            for p in protocols:
-                if os.path.exists(self.get_list_file('dev', 'for_probes', p)) and not os.path.exists(
-                        self.get_list_file('dev', 'for_scores', p)):
-                    use_dense_probes = True
-                elif not os.path.exists(self.get_list_file('dev', 'for_probes', p)) and os.path.exists(
-                        self.get_list_file('dev', 'for_scores', p)):
-                    use_dense_probes = False
-                else:
-                    raise ValueError(
-                        "Unable to determine, which way of probing should be used, looking at the protocol (directory) '%s'. Please specify." % p)
-                list_use_dense_probes.append(use_dense_probes)
-            if len(set(list_use_dense_probes)) == 1:
-                self.m_use_dense_probes = list_use_dense_probes[0]
-            else:
-                raise ValueError(
-                    "Unable to determine, which way of probing should be used, since this is not consistent accross protocols. Please specify.")
+        self.m_use_dense_probe_file_list = use_dense_probe_file_list
 
 
     def _list_reader(self, protocol):
-      if protocol not in self.list_readers:
-        if protocol is not None:
-          protocol_dir = os.path.join(self.get_base_directory(), protocol)
-          if not os.path.isdir(protocol_dir):
-            raise ValueError("The directory %s for the given protocol '%s' does not exist" % (protocol_dir, protocol))
-        self.list_readers[protocol] = ListReader(self.keep_read_lists_in_memory)
+        if protocol not in self.list_readers:
+            if protocol is not None:
+                protocol_dir = os.path.join(self.get_base_directory(), protocol)
+                if not os.path.isdir(protocol_dir):
+                    raise ValueError("The directory %s for the given protocol '%s' does not exist" % (protocol_dir, protocol))
+            self.list_readers[protocol] = ListReader(self.keep_read_lists_in_memory)
 
-      return self.list_readers[protocol]
+        return self.list_readers[protocol]
 
     def _make_bio(self, files):
         return [self.bio_file_class(client_id=f.client_id, path=f.path, file_id=f.id) for f in files]
 
     def all_files(self, groups=['dev']):
-        files = self.objects(groups, self.protocol, None, None, **self.all_files_options)
+        files = self.objects(groups, self.protocol, **self.all_files_options)
         # add all files that belong to the ZT-norm
         for group in groups:
             if group == 'world':
                 continue
             if self.implements_zt(self.protocol, group):
-                files += self.tobjects(group, self.protocol, None)
+                files += self.tobjects(group, self.protocol)
                 files += self.zobjects(group, self.protocol, **self.z_probe_options)
         return self.sort(self._make_bio(files))
 
-    def groups(self, protocol=None):
+    def groups(self, protocol=None, add_world=True, add_subworld=True):
         """This function returns the list of groups for this database.
 
         protocol : str or ``None``
           The protocol for which the groups should be retrieved.
 
+        add_world : bool
+          Add the world groups?
+
+        add_subworld : bool
+          Add the sub-world groups? Only valid, when ``add_world=True``
+
         Returns: a list of groups
         """
 
@@ -267,23 +231,27 @@ class FileListBioDatabase(ZTBioDatabase):
                 groups.append('dev')
             if os.path.isdir(os.path.join(self.get_base_directory(), protocol, self.m_eval_subdir)):
                 groups.append('eval')
-            if os.path.isfile(os.path.join(self.get_base_directory(), protocol, self.m_world_filename)):
-                groups.append('world')
-            if os.path.isfile(os.path.join(self.get_base_directory(), protocol, self.m_optional_world_1_filename)):
-                groups.append('optional_world_1')
-            if os.path.isfile(os.path.join(self.get_base_directory(), protocol, self.m_optional_world_2_filename)):
-                groups.append('optional_world_2')
+            if add_world:
+                if os.path.isfile(os.path.join(self.get_base_directory(), protocol, self.m_world_filename)):
+                    groups.append('world')
+            if add_world and add_subworld:
+                if os.path.isfile(os.path.join(self.get_base_directory(), protocol, self.m_optional_world_1_filename)):
+                    groups.append('optional_world_1')
+                if os.path.isfile(os.path.join(self.get_base_directory(), protocol, self.m_optional_world_2_filename)):
+                    groups.append('optional_world_2')
         else:
             if os.path.isdir(os.path.join(self.get_base_directory(), self.m_dev_subdir)):
                 groups.append('dev')
             if os.path.isdir(os.path.join(self.get_base_directory(), self.m_eval_subdir)):
                 groups.append('eval')
-            if os.path.isfile(os.path.join(self.get_base_directory(), self.m_world_filename)):
-                groups.append('world')
-            if os.path.isfile(os.path.join(self.get_base_directory(), self.m_optional_world_1_filename)):
-                groups.append('optional_world_1')
-            if os.path.isfile(os.path.join(self.get_base_directory(), self.m_optional_world_2_filename)):
-                groups.append('optional_world_2')
+            if add_world:
+                if os.path.isfile(os.path.join(self.get_base_directory(), self.m_world_filename)):
+                    groups.append('world')
+            if add_world and add_subworld:
+                if os.path.isfile(os.path.join(self.get_base_directory(), self.m_optional_world_1_filename)):
+                    groups.append('optional_world_1')
+                if os.path.isfile(os.path.join(self.get_base_directory(), self.m_optional_world_2_filename)):
+                    groups.append('optional_world_2')
         return groups
 
     def implements_zt(self, protocol=None, groups=None):
@@ -300,16 +268,36 @@ class FileListBioDatabase(ZTBioDatabase):
         Returns:
           ``True`` if the all file lists for ZT score normalization exist, otherwise ``False``.
         """
-        groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval'))
-
         protocol = protocol or self.protocol
+        groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol, add_world=False))
+
         for group in groups:
             for t in ['for_tnorm', 'for_znorm']:
-                if not os.path.exists(self.get_list_file(group, t, protocol)):
+                if not os.path.exists(self._get_list_file(group, t, protocol)):
                     return False
         # all files exist
         return True
 
+    def uses_dense_probe_file(self, protocol):
+        """Determines if a dense probe file list is used based on the existence of parameters."""
+        # return, whatever was specified in constructor, if not None
+        if self.m_use_dense_probe_file_list is not None:
+            return self.m_use_dense_probe_file_list
+
+        # check the existence of the files
+        probes = True
+        scores = True
+        for group in self.groups(protocol, add_world=False):
+            probes = probes and os.path.exists(self._get_list_file(group, type='for_probes', protocol=protocol))
+            scores = probes and os.path.exists(self._get_list_file(group, type='for_scores', protocol=protocol))
+        # decide, which score files are available
+        if probes and not scores:
+            return True
+        if not probes and scores:
+            return False
+        raise ValueError("Unable to determine, which way of probing should be used. Please specify.")
+
+
     def get_base_directory(self):
         """Returns the base directory where the filelists defining the database
            are located."""
@@ -322,7 +310,7 @@ class FileListBioDatabase(ZTBioDatabase):
         if not os.path.isdir(self.filelists_directory):
             raise RuntimeError('Invalid directory specified %s.' % (self.filelists_directory))
 
-    def get_list_file(self, group, type=None, protocol=None):
+    def _get_list_file(self, group, type=None, protocol=None):
         if protocol:
             base_directory = os.path.join(self.get_base_directory(), protocol)
         else:
@@ -361,15 +349,13 @@ class FileListBioDatabase(ZTBioDatabase):
 
         Returns: The client id for the given model id, if found.
         """
-        # compatibility reasons
-        groups = group
-        groups = self.check_parameters_for_validity(groups, "group",
-                                                    ('dev', 'eval', 'world', 'optional_world_1', 'optional_world_2'),
-                                                    default_parameters=('dev', 'eval', 'world'))
-
         protocol = self.protocol
+        groups = self.check_parameters_for_validity(group, "group",
+                                                    self.groups(protocol),
+                                                    default_parameters=self.groups(protocol, add_subworld=False))
+
         for group in groups:
-            model_dict = self._list_reader(protocol).read_models(self.get_list_file(group, 'for_models', protocol), group,
+            model_dict = self._list_reader(protocol).read_models(self._get_list_file(group, 'for_models', protocol), group,
                                                         'for_models')
             if model_id in model_dict:
                 return model_dict[model_id]
@@ -394,12 +380,11 @@ class FileListBioDatabase(ZTBioDatabase):
 
         Returns: The client id for the given model id of a T-Norm model, if found.
         """
-        groups = group
-        groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval'))
-
         protocol = self.protocol
+        groups = self.check_parameters_for_validity(group, "group", self.groups(protocol, add_world=False))
+
         for group in groups:
-            model_dict = self._list_reader(protocol).read_models(self.get_list_file(group, 'for_tnorm', protocol), group,
+            model_dict = self._list_reader(protocol).read_models(self._get_list_file(group, 'for_tnorm', protocol), group,
                                                         'for_tnorm')
             if t_model_id in model_dict:
                 return model_dict[t_model_id]
@@ -464,7 +449,7 @@ class FileListBioDatabase(ZTBioDatabase):
         protocol = protocol or self.protocol
         # read all lists for all groups and extract the model ids
         for group in groups:
-            files = self._list_reader(protocol).read_list(self.get_list_file(group, type, protocol), group, type)
+            files = self._list_reader(protocol).read_list(self._get_list_file(group, type, protocol), group, type)
             for file in files:
                 ids.add(file.client_id)
         return ids
@@ -485,8 +470,8 @@ class FileListBioDatabase(ZTBioDatabase):
 
         protocol = protocol or self.protocol
         groups = self.check_parameters_for_validity(groups, "group",
-                                                    ('dev', 'eval', 'world', 'optional_world_1', 'optional_world_2'),
-                                                    default_parameters=('dev', 'eval', 'world'))
+                                                    self.groups(protocol),
+                                                    default_parameters=self.groups(protocol, add_subworld=False))
 
         return self.__client_id_list__(groups, 'for_models', protocol)
 
@@ -505,7 +490,7 @@ class FileListBioDatabase(ZTBioDatabase):
         """
 
         protocol = protocol or self.protocol
-        groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval'))
+        groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol, add_world=False))
 
         return self.__client_id_list__(groups, 'for_tnorm', protocol)
 
@@ -524,7 +509,7 @@ class FileListBioDatabase(ZTBioDatabase):
         """
 
         protocol = protocol or self.protocol
-        groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval'))
+        groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol, add_world=False))
 
         return self.__client_id_list__(groups, 'for_znorm', protocol)
 
@@ -533,7 +518,7 @@ class FileListBioDatabase(ZTBioDatabase):
         protocol = protocol or self.protocol
         # read all lists for all groups and extract the model ids
         for group in groups:
-            dict = self._list_reader(protocol).read_models(self.get_list_file(group, type, protocol), group, type)
+            dict = self._list_reader(protocol).read_models(self._get_list_file(group, type, protocol), group, type)
             ids.update(dict.keys())
         return list(ids)
 
@@ -551,10 +536,7 @@ class FileListBioDatabase(ZTBioDatabase):
         Returns: A list containing all the model ids which have the given properties.
         """
         protocol = protocol or self.protocol
-
-        groups = self.check_parameters_for_validity(groups, "group",
-                                                    ('dev', 'eval', 'world', 'optional_world_1', 'optional_world_2'),
-                                                    default_parameters=('dev', 'eval', 'world'))
+        groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol=protocol))
 
         return self.__model_id_list__(groups, 'for_models', protocol)
 
@@ -572,8 +554,7 @@ class FileListBioDatabase(ZTBioDatabase):
         Returns: A list containing all the T-Norm model ids belonging to the given group.
         """
         protocol = protocol or self.protocol
-
-        groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval'))
+        groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol, add_world=False))
 
         return self.__model_id_list__(groups, 'for_tnorm', protocol)
 
@@ -598,8 +579,7 @@ class FileListBioDatabase(ZTBioDatabase):
 
         groups : str or [str] or ``None``
           One of the groups ("dev", "eval", "world", "optional_world_1", "optional_world_2") or a tuple with several of them.
-          If 'None' is given (this is the default), it is considered the same as a
-          tuple with all possible values.
+          If 'None' is given (this is the default), it is considered to be the existing subset of ``("world", "dev", "eval")``.
 
         classes : str or [str] or ``None``
           The classes (types of accesses) to be retrieved ('client', 'impostor')
@@ -611,13 +591,13 @@ class FileListBioDatabase(ZTBioDatabase):
         """
 
         protocol = protocol or self.protocol
-        if self.m_use_dense_probes and classes is not None:
+        if self.uses_dense_probe_file(protocol) and classes is not None:
             raise ValueError("To be able to use the 'classes' keyword, please use the 'for_scores.lst' list file.")
 
         purposes = self.check_parameters_for_validity(purposes, "purpose", ('enroll', 'probe'))
         groups = self.check_parameters_for_validity(groups, "group",
-                                                    ('dev', 'eval', 'world', 'optional_world_1', 'optional_world_2'),
-                                                    default_parameters=('dev', 'eval', 'world'))
+                                                    self.groups(protocol),
+                                                    default_parameters=self.groups(protocol, add_subworld=False))
         classes = self.check_parameters_for_validity(classes, "class", ('client', 'impostor'))
 
         if isinstance(model_ids, six.string_types):
@@ -627,29 +607,26 @@ class FileListBioDatabase(ZTBioDatabase):
         lists = []
         probe_lists = []
         if 'world' in groups:
-            lists.append(self._list_reader(protocol).read_list(self.get_list_file('world', protocol=protocol), 'world'))
+            lists.append(self._list_reader(protocol).read_list(self._get_list_file('world', protocol=protocol), 'world'))
         if 'optional_world_1' in groups:
-            lists.append(self._list_reader(protocol).read_list(self.get_list_file('optional_world_1', protocol=protocol),
+            lists.append(self._list_reader(protocol).read_list(self._get_list_file('optional_world_1', protocol=protocol),
                                                       'optional_world_1'))
         if 'optional_world_2' in groups:
-            lists.append(self._list_reader(protocol).read_list(self.get_list_file('optional_world_2', protocol=protocol),
+            lists.append(self._list_reader(protocol).read_list(self._get_list_file('optional_world_2', protocol=protocol),
                                                       'optional_world_2'))
 
         for group in ('dev', 'eval'):
             if group in groups:
                 if 'enroll' in purposes:
                     lists.append(
-                        self._list_reader(protocol).read_list(self.get_list_file(group, 'for_models', protocol=protocol), group,
-                                                     'for_models'))
+                        self._list_reader(protocol).read_list(self._get_list_file(group, 'for_models', protocol=protocol), group, 'for_models'))
                 if 'probe' in purposes:
-                    if self.m_use_dense_probes:
+                    if self.uses_dense_probe_file(protocol):
                         probe_lists.append(
-                            self._list_reader(protocol).read_list(self.get_list_file(group, 'for_probes', protocol=protocol),
-                                                         group, 'for_probes'))
+                            self._list_reader(protocol).read_list(self._get_list_file(group, 'for_probes', protocol=protocol), group, 'for_probes'))
                     else:
                         probe_lists.append(
-                            self._list_reader(protocol).read_list(self.get_list_file(group, 'for_scores', protocol=protocol),
-                                                         group, 'for_scores'))
+                            self._list_reader(protocol).read_list(self._get_list_file(group, 'for_scores', protocol=protocol), group, 'for_scores'))
 
         # now, go through the lists and filter the elements
 
@@ -668,7 +645,7 @@ class FileListBioDatabase(ZTBioDatabase):
 
         # probe files; filter by model id and by class
         for list in probe_lists:
-            if self.m_use_dense_probes:
+            if self.uses_dense_probe_file(protocol):
                 # dense probing is used; do not filter over the model ids and not over the classes
                 # -> just add all probe files
                 for file in list:
@@ -710,8 +687,7 @@ class FileListBioDatabase(ZTBioDatabase):
         Returns: A list of :py:class:`BioFile` objects considering all the filtering criteria.
         """
         protocol = protocol or self.protocol
-
-        groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval'))
+        groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol, add_world=False))
 
         if (isinstance(model_ids, six.string_types)):
             model_ids = (model_ids,)
@@ -720,7 +696,7 @@ class FileListBioDatabase(ZTBioDatabase):
         # we assume that there is no duplicate file here...
         retval = []
         for group in groups:
-            for file in self._list_reader(protocol).read_list(self.get_list_file(group, 'for_tnorm', protocol), group,
+            for file in self._list_reader(protocol).read_list(self._get_list_file(group, 'for_tnorm', protocol), group,
                                                      'for_tnorm'):
                 if model_ids is None or file._model_id in model_ids:
                     retval.append(file)
@@ -742,14 +718,14 @@ class FileListBioDatabase(ZTBioDatabase):
         """
 
         protocol = protocol or self.protocol
-        groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval'))
+        groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol, add_world=False))
 
         # iterate over the lists and extract the files
         # we assume that there is no duplicate file here...
         retval = []
         for group in groups:
             retval.extend([file for file in
-                           self._list_reader(protocol).read_list(self.get_list_file(group, 'for_znorm', protocol), group,
+                           self._list_reader(protocol).read_list(self._get_list_file(group, 'for_znorm', protocol), group,
                                                         'for_znorm')])
 
         return self._make_bio(retval)
diff --git a/bob/bio/base/test/data/example_filelist2/dev/for_models.lst b/bob/bio/base/test/data/example_filelist2/dev/for_models.lst
new file mode 100644
index 00000000..7bdbf751
--- /dev/null
+++ b/bob/bio/base/test/data/example_filelist2/dev/for_models.lst
@@ -0,0 +1,12 @@
+data/model3_session1_sample1 3 3
+data/model3_session1_sample2 3 3
+data/model3_session1_sample3 3 3
+data/model3_session2_sample1 3 3
+data/model4_session1_sample1 4 4
+data/model4_session1_sample2 4 4
+data/model4_session1_sample3 4 4
+data/model4_session2_sample1 4 4
+data/model5_session1_sample1 5 5
+data/model5_session1_sample2 5 5
+data/model5_session1_sample3 5 5
+data/model5_session2_sample1 5 5
diff --git a/bob/bio/base/test/data/example_filelist2/dev/for_scores.lst b/bob/bio/base/test/data/example_filelist2/dev/for_scores.lst
new file mode 100644
index 00000000..340e7ec1
--- /dev/null
+++ b/bob/bio/base/test/data/example_filelist2/dev/for_scores.lst
@@ -0,0 +1,14 @@
+data/model3_session3_sample1 3 3 3
+data/model3_session3_sample2 3 3 3
+data/model3_session3_sample3 3 3 3
+data/model3_session4_sample1 3 3 3
+data/model4_session3_sample1 3 3 4
+data/model4_session3_sample2 3 3 4
+data/model4_session3_sample1 4 4 4
+data/model4_session3_sample2 4 4 4
+data/model4_session3_sample3 4 4 4
+data/model4_session4_sample1 4 4 4
+data/model3_session3_sample1 4 4 3
+data/model3_session3_sample2 4 4 3
+data/model5_session3_sample1 5 3 5
+data/model5_session3_sample1 5 5 5
diff --git a/bob/bio/base/test/test_filelist.py b/bob/bio/base/test/test_filelist.py
index fcaa4627..405591a8 100644
--- a/bob/bio/base/test/test_filelist.py
+++ b/bob/bio/base/test/test_filelist.py
@@ -125,11 +125,19 @@ def test_query_protocol():
     assert db.client_id_from_model_id('6', group=None) == '6'
     assert db.client_id_from_t_model_id('7', group=None) == '7'
 
-    nose.tools.assert_raises(ValueError, db.objects, protocol='non-existent')
+
+    # check other protocols
+    assert len(db.objects(protocol='non-existent')) == 0
+
+    prot = 'example_filelist2'
+    assert len(db.model_ids_with_protocol(protocol=prot)) == 3  # 3 model ids for dev only
+    nose.tools.assert_raises(ValueError, db.model_ids_with_protocol, protocol=prot, groups='eval') # eval does not exist for this protocol
+    assert len(db.objects(protocol=prot, groups='dev', purposes='enroll')) == 12
+    assert len(db.objects(protocol=prot, groups='dev', purposes='probe')) == 9
 
 
 def test_query_dense():
-    db = FileListBioDatabase(example_dir, 'test', probes_filename='for_probes.lst')
+    db = FileListBioDatabase(example_dir, 'test', use_dense_probe_file_list=True)
 
     assert len(db.objects(groups='world')) == 8  # 8 samples in the world set
 
-- 
GitLab