Skip to content
Snippets Groups Projects
Commit 21954612 authored by Manuel Günther's avatar Manuel Günther
Browse files

Implemented better behavior when protocol is present; added possibility to only have 'dev' group

parent f7cb7c9f
Branches
Tags
1 merge request!59Resolve "strange behaviour when retrieving objects from bob.bio.base filelist database query"
Pipeline #
......@@ -182,49 +182,7 @@ class FileListBioDatabase(ZTBioDatabase):
# Z-Norm files format: filename client_id
self.m_znorm_filename = znorm_filename if znorm_filename is not None else 'for_znorm.lst'
# decide, which scoring type we have:
if probes_filename is not None and scores_filename is None:
self.m_use_dense_probes = True
elif probes_filename is None and scores_filename is not None:
self.m_use_dense_probes = False
elif use_dense_probe_file_list is not None:
self.m_use_dense_probes = use_dense_probe_file_list
# Then direct path to a given protocol
elif os.path.isdir(os.path.join(self.get_base_directory(), self.m_dev_subdir)) or os.path.isfile(
os.path.join(self.get_base_directory(), self.m_world_filename)):
if os.path.exists(self.get_list_file('dev', 'for_probes')) and not os.path.exists(
self.get_list_file('dev', 'for_scores')):
self.m_use_dense_probes = True
elif not os.path.exists(self.get_list_file('dev', 'for_probes')) and os.path.exists(
self.get_list_file('dev', 'for_scores')):
self.m_use_dense_probes = False
else:
raise ValueError("Unable to determine, which way of probing should be used. Please specify.")
# Then path to a directory that contains several subdirectories (one for each protocol)
else:
# Look at subdirectories for each protocol
protocols = [p for p in os.listdir(self.get_base_directory()) if
os.path.isdir(os.path.join(self.get_base_directory(), p))]
if len(protocols) == 0:
raise ValueError(
"Unable to determine, which way of probing should be used (no protocol directories found). Please specify.")
list_use_dense_probes = []
for p in protocols:
if os.path.exists(self.get_list_file('dev', 'for_probes', p)) and not os.path.exists(
self.get_list_file('dev', 'for_scores', p)):
use_dense_probes = True
elif not os.path.exists(self.get_list_file('dev', 'for_probes', p)) and os.path.exists(
self.get_list_file('dev', 'for_scores', p)):
use_dense_probes = False
else:
raise ValueError(
"Unable to determine, which way of probing should be used, looking at the protocol (directory) '%s'. Please specify." % p)
list_use_dense_probes.append(use_dense_probes)
if len(set(list_use_dense_probes)) == 1:
self.m_use_dense_probes = list_use_dense_probes[0]
else:
raise ValueError(
"Unable to determine, which way of probing should be used, since this is not consistent accross protocols. Please specify.")
self.m_use_dense_probe_file_list = use_dense_probe_file_list
def _list_reader(self, protocol):
......@@ -241,22 +199,28 @@ class FileListBioDatabase(ZTBioDatabase):
return [self.bio_file_class(client_id=f.client_id, path=f.path, file_id=f.id) for f in files]
def all_files(self, groups=['dev']):
files = self.objects(groups, self.protocol, None, None, **self.all_files_options)
files = self.objects(groups, self.protocol, **self.all_files_options)
# add all files that belong to the ZT-norm
for group in groups:
if group == 'world':
continue
if self.implements_zt(self.protocol, group):
files += self.tobjects(group, self.protocol, None)
files += self.tobjects(group, self.protocol)
files += self.zobjects(group, self.protocol, **self.z_probe_options)
return self.sort(self._make_bio(files))
def groups(self, protocol=None):
def groups(self, protocol=None, add_world=True, add_subworld=True):
"""This function returns the list of groups for this database.
protocol : str or ``None``
The protocol for which the groups should be retrieved.
add_world : bool
Add the world groups?
add_subworld : bool
Add the sub-world groups? Only valid, when ``add_world=True``
Returns: a list of groups
"""
......@@ -267,8 +231,10 @@ class FileListBioDatabase(ZTBioDatabase):
groups.append('dev')
if os.path.isdir(os.path.join(self.get_base_directory(), protocol, self.m_eval_subdir)):
groups.append('eval')
if add_world:
if os.path.isfile(os.path.join(self.get_base_directory(), protocol, self.m_world_filename)):
groups.append('world')
if add_world and add_subworld:
if os.path.isfile(os.path.join(self.get_base_directory(), protocol, self.m_optional_world_1_filename)):
groups.append('optional_world_1')
if os.path.isfile(os.path.join(self.get_base_directory(), protocol, self.m_optional_world_2_filename)):
......@@ -278,8 +244,10 @@ class FileListBioDatabase(ZTBioDatabase):
groups.append('dev')
if os.path.isdir(os.path.join(self.get_base_directory(), self.m_eval_subdir)):
groups.append('eval')
if add_world:
if os.path.isfile(os.path.join(self.get_base_directory(), self.m_world_filename)):
groups.append('world')
if add_world and add_subworld:
if os.path.isfile(os.path.join(self.get_base_directory(), self.m_optional_world_1_filename)):
groups.append('optional_world_1')
if os.path.isfile(os.path.join(self.get_base_directory(), self.m_optional_world_2_filename)):
......@@ -300,16 +268,36 @@ class FileListBioDatabase(ZTBioDatabase):
Returns:
``True`` if the all file lists for ZT score normalization exist, otherwise ``False``.
"""
groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval'))
protocol = protocol or self.protocol
groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol, add_world=False))
for group in groups:
for t in ['for_tnorm', 'for_znorm']:
if not os.path.exists(self.get_list_file(group, t, protocol)):
if not os.path.exists(self._get_list_file(group, t, protocol)):
return False
# all files exist
return True
def uses_dense_probe_file(self, protocol):
"""Determines if a dense probe file list is used based on the existence of parameters."""
# return, whatever was specified in constructor, if not None
if self.m_use_dense_probe_file_list is not None:
return self.m_use_dense_probe_file_list
# check the existence of the files
probes = True
scores = True
for group in self.groups(protocol, add_world=False):
probes = probes and os.path.exists(self._get_list_file(group, type='for_probes', protocol=protocol))
scores = probes and os.path.exists(self._get_list_file(group, type='for_scores', protocol=protocol))
# decide, which score files are available
if probes and not scores:
return True
if not probes and scores:
return False
raise ValueError("Unable to determine, which way of probing should be used. Please specify.")
def get_base_directory(self):
"""Returns the base directory where the filelists defining the database
are located."""
......@@ -322,7 +310,7 @@ class FileListBioDatabase(ZTBioDatabase):
if not os.path.isdir(self.filelists_directory):
raise RuntimeError('Invalid directory specified %s.' % (self.filelists_directory))
def get_list_file(self, group, type=None, protocol=None):
def _get_list_file(self, group, type=None, protocol=None):
if protocol:
base_directory = os.path.join(self.get_base_directory(), protocol)
else:
......@@ -361,15 +349,13 @@ class FileListBioDatabase(ZTBioDatabase):
Returns: The client id for the given model id, if found.
"""
# compatibility reasons
groups = group
groups = self.check_parameters_for_validity(groups, "group",
('dev', 'eval', 'world', 'optional_world_1', 'optional_world_2'),
default_parameters=('dev', 'eval', 'world'))
protocol = self.protocol
groups = self.check_parameters_for_validity(group, "group",
self.groups(protocol),
default_parameters=self.groups(protocol, add_subworld=False))
for group in groups:
model_dict = self._list_reader(protocol).read_models(self.get_list_file(group, 'for_models', protocol), group,
model_dict = self._list_reader(protocol).read_models(self._get_list_file(group, 'for_models', protocol), group,
'for_models')
if model_id in model_dict:
return model_dict[model_id]
......@@ -394,12 +380,11 @@ class FileListBioDatabase(ZTBioDatabase):
Returns: The client id for the given model id of a T-Norm model, if found.
"""
groups = group
groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval'))
protocol = self.protocol
groups = self.check_parameters_for_validity(group, "group", self.groups(protocol, add_world=False))
for group in groups:
model_dict = self._list_reader(protocol).read_models(self.get_list_file(group, 'for_tnorm', protocol), group,
model_dict = self._list_reader(protocol).read_models(self._get_list_file(group, 'for_tnorm', protocol), group,
'for_tnorm')
if t_model_id in model_dict:
return model_dict[t_model_id]
......@@ -464,7 +449,7 @@ class FileListBioDatabase(ZTBioDatabase):
protocol = protocol or self.protocol
# read all lists for all groups and extract the model ids
for group in groups:
files = self._list_reader(protocol).read_list(self.get_list_file(group, type, protocol), group, type)
files = self._list_reader(protocol).read_list(self._get_list_file(group, type, protocol), group, type)
for file in files:
ids.add(file.client_id)
return ids
......@@ -485,8 +470,8 @@ class FileListBioDatabase(ZTBioDatabase):
protocol = protocol or self.protocol
groups = self.check_parameters_for_validity(groups, "group",
('dev', 'eval', 'world', 'optional_world_1', 'optional_world_2'),
default_parameters=('dev', 'eval', 'world'))
self.groups(protocol),
default_parameters=self.groups(protocol, add_subworld=False))
return self.__client_id_list__(groups, 'for_models', protocol)
......@@ -505,7 +490,7 @@ class FileListBioDatabase(ZTBioDatabase):
"""
protocol = protocol or self.protocol
groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval'))
groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol, add_world=False))
return self.__client_id_list__(groups, 'for_tnorm', protocol)
......@@ -524,7 +509,7 @@ class FileListBioDatabase(ZTBioDatabase):
"""
protocol = protocol or self.protocol
groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval'))
groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol, add_world=False))
return self.__client_id_list__(groups, 'for_znorm', protocol)
......@@ -533,7 +518,7 @@ class FileListBioDatabase(ZTBioDatabase):
protocol = protocol or self.protocol
# read all lists for all groups and extract the model ids
for group in groups:
dict = self._list_reader(protocol).read_models(self.get_list_file(group, type, protocol), group, type)
dict = self._list_reader(protocol).read_models(self._get_list_file(group, type, protocol), group, type)
ids.update(dict.keys())
return list(ids)
......@@ -551,10 +536,7 @@ class FileListBioDatabase(ZTBioDatabase):
Returns: A list containing all the model ids which have the given properties.
"""
protocol = protocol or self.protocol
groups = self.check_parameters_for_validity(groups, "group",
('dev', 'eval', 'world', 'optional_world_1', 'optional_world_2'),
default_parameters=('dev', 'eval', 'world'))
groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol=protocol))
return self.__model_id_list__(groups, 'for_models', protocol)
......@@ -572,8 +554,7 @@ class FileListBioDatabase(ZTBioDatabase):
Returns: A list containing all the T-Norm model ids belonging to the given group.
"""
protocol = protocol or self.protocol
groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval'))
groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol, add_world=False))
return self.__model_id_list__(groups, 'for_tnorm', protocol)
......@@ -598,8 +579,7 @@ class FileListBioDatabase(ZTBioDatabase):
groups : str or [str] or ``None``
One of the groups ("dev", "eval", "world", "optional_world_1", "optional_world_2") or a tuple with several of them.
If 'None' is given (this is the default), it is considered the same as a
tuple with all possible values.
If 'None' is given (this is the default), it is considered to be the existing subset of ``("world", "dev", "eval")``.
classes : str or [str] or ``None``
The classes (types of accesses) to be retrieved ('client', 'impostor')
......@@ -611,13 +591,13 @@ class FileListBioDatabase(ZTBioDatabase):
"""
protocol = protocol or self.protocol
if self.m_use_dense_probes and classes is not None:
if self.uses_dense_probe_file(protocol) and classes is not None:
raise ValueError("To be able to use the 'classes' keyword, please use the 'for_scores.lst' list file.")
purposes = self.check_parameters_for_validity(purposes, "purpose", ('enroll', 'probe'))
groups = self.check_parameters_for_validity(groups, "group",
('dev', 'eval', 'world', 'optional_world_1', 'optional_world_2'),
default_parameters=('dev', 'eval', 'world'))
self.groups(protocol),
default_parameters=self.groups(protocol, add_subworld=False))
classes = self.check_parameters_for_validity(classes, "class", ('client', 'impostor'))
if isinstance(model_ids, six.string_types):
......@@ -627,29 +607,26 @@ class FileListBioDatabase(ZTBioDatabase):
lists = []
probe_lists = []
if 'world' in groups:
lists.append(self._list_reader(protocol).read_list(self.get_list_file('world', protocol=protocol), 'world'))
lists.append(self._list_reader(protocol).read_list(self._get_list_file('world', protocol=protocol), 'world'))
if 'optional_world_1' in groups:
lists.append(self._list_reader(protocol).read_list(self.get_list_file('optional_world_1', protocol=protocol),
lists.append(self._list_reader(protocol).read_list(self._get_list_file('optional_world_1', protocol=protocol),
'optional_world_1'))
if 'optional_world_2' in groups:
lists.append(self._list_reader(protocol).read_list(self.get_list_file('optional_world_2', protocol=protocol),
lists.append(self._list_reader(protocol).read_list(self._get_list_file('optional_world_2', protocol=protocol),
'optional_world_2'))
for group in ('dev', 'eval'):
if group in groups:
if 'enroll' in purposes:
lists.append(
self._list_reader(protocol).read_list(self.get_list_file(group, 'for_models', protocol=protocol), group,
'for_models'))
self._list_reader(protocol).read_list(self._get_list_file(group, 'for_models', protocol=protocol), group, 'for_models'))
if 'probe' in purposes:
if self.m_use_dense_probes:
if self.uses_dense_probe_file(protocol):
probe_lists.append(
self._list_reader(protocol).read_list(self.get_list_file(group, 'for_probes', protocol=protocol),
group, 'for_probes'))
self._list_reader(protocol).read_list(self._get_list_file(group, 'for_probes', protocol=protocol), group, 'for_probes'))
else:
probe_lists.append(
self._list_reader(protocol).read_list(self.get_list_file(group, 'for_scores', protocol=protocol),
group, 'for_scores'))
self._list_reader(protocol).read_list(self._get_list_file(group, 'for_scores', protocol=protocol), group, 'for_scores'))
# now, go through the lists and filter the elements
......@@ -668,7 +645,7 @@ class FileListBioDatabase(ZTBioDatabase):
# probe files; filter by model id and by class
for list in probe_lists:
if self.m_use_dense_probes:
if self.uses_dense_probe_file(protocol):
# dense probing is used; do not filter over the model ids and not over the classes
# -> just add all probe files
for file in list:
......@@ -710,8 +687,7 @@ class FileListBioDatabase(ZTBioDatabase):
Returns: A list of :py:class:`BioFile` objects considering all the filtering criteria.
"""
protocol = protocol or self.protocol
groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval'))
groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol, add_world=False))
if (isinstance(model_ids, six.string_types)):
model_ids = (model_ids,)
......@@ -720,7 +696,7 @@ class FileListBioDatabase(ZTBioDatabase):
# we assume that there is no duplicate file here...
retval = []
for group in groups:
for file in self._list_reader(protocol).read_list(self.get_list_file(group, 'for_tnorm', protocol), group,
for file in self._list_reader(protocol).read_list(self._get_list_file(group, 'for_tnorm', protocol), group,
'for_tnorm'):
if model_ids is None or file._model_id in model_ids:
retval.append(file)
......@@ -742,14 +718,14 @@ class FileListBioDatabase(ZTBioDatabase):
"""
protocol = protocol or self.protocol
groups = self.check_parameters_for_validity(groups, "group", ('dev', 'eval'))
groups = self.check_parameters_for_validity(groups, "group", self.groups(protocol, add_world=False))
# iterate over the lists and extract the files
# we assume that there is no duplicate file here...
retval = []
for group in groups:
retval.extend([file for file in
self._list_reader(protocol).read_list(self.get_list_file(group, 'for_znorm', protocol), group,
self._list_reader(protocol).read_list(self._get_list_file(group, 'for_znorm', protocol), group,
'for_znorm')])
return self._make_bio(retval)
......
data/model3_session1_sample1 3 3
data/model3_session1_sample2 3 3
data/model3_session1_sample3 3 3
data/model3_session2_sample1 3 3
data/model4_session1_sample1 4 4
data/model4_session1_sample2 4 4
data/model4_session1_sample3 4 4
data/model4_session2_sample1 4 4
data/model5_session1_sample1 5 5
data/model5_session1_sample2 5 5
data/model5_session1_sample3 5 5
data/model5_session2_sample1 5 5
data/model3_session3_sample1 3 3 3
data/model3_session3_sample2 3 3 3
data/model3_session3_sample3 3 3 3
data/model3_session4_sample1 3 3 3
data/model4_session3_sample1 3 3 4
data/model4_session3_sample2 3 3 4
data/model4_session3_sample1 4 4 4
data/model4_session3_sample2 4 4 4
data/model4_session3_sample3 4 4 4
data/model4_session4_sample1 4 4 4
data/model3_session3_sample1 4 4 3
data/model3_session3_sample2 4 4 3
data/model5_session3_sample1 5 3 5
data/model5_session3_sample1 5 5 5
......@@ -125,11 +125,19 @@ def test_query_protocol():
assert db.client_id_from_model_id('6', group=None) == '6'
assert db.client_id_from_t_model_id('7', group=None) == '7'
nose.tools.assert_raises(ValueError, db.objects, protocol='non-existent')
# check other protocols
assert len(db.objects(protocol='non-existent')) == 0
prot = 'example_filelist2'
assert len(db.model_ids_with_protocol(protocol=prot)) == 3 # 3 model ids for dev only
nose.tools.assert_raises(ValueError, db.model_ids_with_protocol, protocol=prot, groups='eval') # eval does not exist for this protocol
assert len(db.objects(protocol=prot, groups='dev', purposes='enroll')) == 12
assert len(db.objects(protocol=prot, groups='dev', purposes='probe')) == 9
def test_query_dense():
db = FileListBioDatabase(example_dir, 'test', probes_filename='for_probes.lst')
db = FileListBioDatabase(example_dir, 'test', use_dense_probe_file_list=True)
assert len(db.objects(groups='world')) == 8 # 8 samples in the world set
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment