Commit 72c66d7b authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Merge branch '67-zt-files-are-processed-even-when-no-zt-processing-is-wanted' into 'master'

Resolve "ZT files are processed even when no ZT processing is wanted"

Closes #67

See merge request !80
parents 962a8513 82147315
Pipeline #11416 passed with stages
in 22 minutes and 33 seconds
......@@ -370,7 +370,7 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)):
"""
return sorted(self.model_ids_with_protocol(groups=groups, protocol=self.protocol))
def all_files(self, groups=None):
def all_files(self, groups=None, **kwargs):
"""all_files(groups=None) -> files
Returns all files of the database, respecting the current protocol.
......@@ -382,6 +382,8 @@ class BioDatabase(six.with_metaclass(abc.ABCMeta, bob.db.base.Database)):
The groups to get the data for.
If ``None``, data for all groups is returned.
kwargs: ignored
**Returns:**
files : [:py:class:`bob.bio.base.database.BioFile`]
......@@ -640,7 +642,7 @@ class ZTBioDatabase(BioDatabase):
"""
raise NotImplementedError("This function must be implemented in your derived class.")
def all_files(self, groups=['dev']):
def all_files(self, groups=['dev'], add_zt_files=True):
"""all_files(groups=None) -> files
Returns all files of the database, including those for ZT norm, respecting the current protocol.
......@@ -652,6 +654,9 @@ class ZTBioDatabase(BioDatabase):
The groups to get the data for.
If ``None``, data for all groups is returned.
add_zt_files: bool
If set (the default), files for ZT score normalization are added.
**Returns:**
files : [:py:class:`bob.bio.base.database.BioFile`]
......@@ -660,11 +665,12 @@ class ZTBioDatabase(BioDatabase):
files = self.objects(protocol=self.protocol, groups=groups, **self.all_files_options)
# add all files that belong to the ZT-norm
for group in groups:
if group == 'world':
continue
files += self.tobjects(protocol=self.protocol, groups=group, model_ids=None)
files += self.zobjects(protocol=self.protocol, groups=group, **self.z_probe_options)
if add_zt_files:
for group in groups:
if group == 'world':
continue
files += self.tobjects(protocol=self.protocol, groups=group, model_ids=None)
files += self.zobjects(protocol=self.protocol, groups=group, **self.z_probe_options)
return self.sort(files)
@abc.abstractmethod
......
......@@ -198,13 +198,15 @@ class FileListBioDatabase(ZTBioDatabase):
def _make_bio(self, files):
return [self.bio_file_class(client_id=f.client_id, path=f.path, file_id=f.id) for f in files]
def all_files(self, groups=['dev']):
def all_files(self, groups=['dev'], add_zt_files=True):
files = self.objects(groups, self.protocol, **self.all_files_options)
# add all files that belong to the ZT-norm
for group in groups:
if group == 'world':
continue
if self.implements_zt(self.protocol, group):
if add_zt_files:
if not self.implements_zt(self.protocol, group):
raise ValueError("ZT score files are requested, but no such files are defined in group %s for protocol %s", group, self.protocol)
files += self.tobjects(group, self.protocol)
files += self.zobjects(group, self.protocol, **self.z_probe_options)
return self.sort(self._make_bio(files))
......
......@@ -21,24 +21,26 @@ regenerate_reference = False
dummy_dir = pkg_resources.resource_filename('bob.bio.base', 'test/dummy')
data_dir = pkg_resources.resource_filename('bob.bio.base', 'test/data')
def _verify(parameters, test_dir, sub_dir, ref_modifier="", score_modifier=('scores',''), counts=3):
def _verify(parameters, test_dir, sub_dir, ref_modifier="", score_modifier=('scores',''), counts=3, check_zt=True):
from bob.bio.base.script.verify import main
try:
main(parameters)
Range = (0,1) if check_zt else (0,)
# assert that the score file exists
score_files = [os.path.join(test_dir, sub_dir, 'Default', norm, '%s-dev%s'%score_modifier) for norm in ('nonorm', 'ztnorm')]
assert os.path.exists(score_files[0]), "Score file %s does not exist" % score_files[0]
assert os.path.exists(score_files[1]), "Score file %s does not exist" % score_files[1]
for i in Range:
assert os.path.exists(score_files[i]), "Score file %s does not exist" % score_files[i]
# also assert that the scores are still the same -- though they have no real meaning
reference_files = [os.path.join(data_dir, 'scores-%s%s-dev'%(norm, ref_modifier)) for norm in ('nonorm', 'ztnorm')]
if regenerate_reference:
for i in (0,1):
for i in Range:
shutil.copy(score_files[i], reference_files[i])
for i in (0,1):
for i in Range:
d = []
# read reference and new data
for score_file in (score_files[i], reference_files[i]):
......@@ -95,6 +97,23 @@ def test_verify_algorithm_noprojection():
_verify(parameters, test_dir, 'algorithm_noprojection')
def test_verify_no_ztnorm():
test_dir = tempfile.mkdtemp(prefix='bobtest_')
# define dummy parameters
parameters = [
'-d', os.path.join(dummy_dir, 'database.py'),
'-p', os.path.join(dummy_dir, 'preprocessor.py'),
'-e', os.path.join(dummy_dir, 'extractor.py'),
'-a', os.path.join(dummy_dir, 'algorithm_noprojection.py'),
'-vs', 'test_nozt',
'--temp-directory', test_dir,
'--result-directory', test_dir
]
_verify(parameters, test_dir, 'test_nozt', check_zt=False)
def test_verify_resources():
test_dir = tempfile.mkdtemp(prefix='bobtest_')
# define dummy parameters
......
......@@ -68,7 +68,8 @@ class FileSelector(object):
score_directories,
zt_score_directories = None,
default_extension = '.hdf5',
compressed_extension = ''
compressed_extension = '',
zt_norm = False
):
"""Initialize the file selector object with the current configuration."""
......@@ -89,6 +90,7 @@ class FileSelector(object):
'extracted' : extracted_directory,
'projected' : projected_directory
}
self.zt_norm = zt_norm
def uses_probe_file_sets(self):
......@@ -108,7 +110,7 @@ class FileSelector(object):
### List of files that will be used for all files
def original_data_list(self, groups = None):
"""Returns the list of original ``BioFile`` objects that can be used for preprocessing."""
return self.database.all_files(groups=groups)
return self.database.all_files(groups=groups,add_zt_files=self.zt_norm)
def original_directory_and_extension(self):
"""Returns the directory and extension of the original files."""
......@@ -116,7 +118,7 @@ class FileSelector(object):
def annotation_list(self, groups = None):
"""Returns the list of annotations objects."""
return self.database.all_files(groups=groups)
return self.database.all_files(groups=groups,add_zt_files=self.zt_norm)
def get_annotations(self, annotation_file):
"""Returns the annotations of the given file."""
......@@ -124,15 +126,15 @@ class FileSelector(object):
def preprocessed_data_list(self, groups = None):
"""Returns the list of preprocessed data files."""
return self.get_paths(self.database.all_files(groups=groups), "preprocessed")
return self.get_paths(self.database.all_files(groups=groups,add_zt_files=self.zt_norm), "preprocessed")
def feature_list(self, groups = None):
"""Returns the list of extracted feature files."""
return self.get_paths(self.database.all_files(groups=groups), "extracted")
return self.get_paths(self.database.all_files(groups=groups,add_zt_files=self.zt_norm), "extracted")
def projected_list(self, groups = None):
"""Returns the list of projected feature files."""
return self.get_paths(self.database.all_files(groups=groups), "projected")
return self.get_paths(self.database.all_files(groups=groups,add_zt_files=self.zt_norm), "projected")
### Training lists
......
......@@ -460,6 +460,7 @@ def initialize(parsers, command_line_parameters=None, skips=[]):
zt_score_directories=[os.path.join(args.temp_directory, protocol, s) for s in args.zt_directories],
compressed_extension='.tar.bz2' if args.write_compressed_score_files else '',
default_extension='.hdf5',
zt_norm = args.zt_norm
)
return args
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment