Skip to content
Snippets Groups Projects
Commit 1fe7ee17 authored by Manuel Günther's avatar Manuel Günther
Browse files

Made FileSelector easier to extend

parent 325b6368
No related branches found
No related tags found
No related merge requests found
......@@ -65,7 +65,7 @@ def _verify(parameters, test_dir, sub_dir, ref_modifier="", score_modifier=('sco
def test_verify_local():
test_dir = tempfile.mkdtemp(prefix='frltest_')
test_dir = tempfile.mkdtemp(prefix='bobtest_')
# define dummy parameters
parameters = [
'-d', os.path.join(dummy_dir, 'database.py'),
......@@ -84,7 +84,7 @@ def test_verify_local():
def test_verify_resources():
test_dir = tempfile.mkdtemp(prefix='frltest_')
test_dir = tempfile.mkdtemp(prefix='bobtest_')
# define dummy parameters
parameters = [
'-d', 'dummy',
......@@ -103,7 +103,7 @@ def test_verify_resources():
def test_verify_commandline():
test_dir = tempfile.mkdtemp(prefix='frltest_')
test_dir = tempfile.mkdtemp(prefix='bobtest_')
# define dummy parameters
parameters = [
'-d', 'bob.bio.base.test.dummy.database.DummyDatabase()',
......@@ -123,7 +123,7 @@ def test_verify_commandline():
@utils.grid_available
def test_verify_parallel():
test_dir = tempfile.mkdtemp(prefix='frltest_')
test_dir = tempfile.mkdtemp(prefix='bobtest_')
test_database = os.path.join(test_dir, "submitted.sql3")
# define dummy parameters
......@@ -146,7 +146,7 @@ def test_verify_parallel():
def test_verify_compressed():
test_dir = tempfile.mkdtemp(prefix='frltest_')
test_dir = tempfile.mkdtemp(prefix='bobtest_')
# define dummy parameters
parameters = [
'-d', 'dummy',
......@@ -166,7 +166,7 @@ def test_verify_compressed():
def test_verify_calibrate():
test_dir = tempfile.mkdtemp(prefix='frltest_')
test_dir = tempfile.mkdtemp(prefix='bobtest_')
# define dummy parameters
parameters = [
'-d', 'dummy',
......@@ -186,7 +186,7 @@ def test_verify_calibrate():
def test_verify_fileset():
test_dir = tempfile.mkdtemp(prefix='frltest_')
test_dir = tempfile.mkdtemp(prefix='bobtest_')
# define dummy parameters
parameters = [
'-d', os.path.join(dummy_dir, 'database.py'),
......@@ -210,7 +210,7 @@ def test_verify_filelist():
import bob.db.verification.filelist
except ImportError:
raise SkipTest("Skipping test since bob.db.verification.filelist is not available")
test_dir = tempfile.mkdtemp(prefix='frltest_')
test_dir = tempfile.mkdtemp(prefix='bobtest_')
# define dummy parameters
parameters = [
'-d', os.path.join(dummy_dir, 'filelist.py'),
......@@ -280,7 +280,7 @@ def test11_baselines_api(self):
def test15_evaluate(self):
# tests our 'evaluate' script using the reference files
test_dir = tempfile.mkdtemp(prefix='frltest_')
test_dir = tempfile.mkdtemp(prefix='bobtest_')
reference_files = ('scores-nonorm-dev', 'scores-ztnorm-dev')
plots = [os.path.join(test_dir, '%s.pdf')%f for f in ['roc', 'cmc', 'det']]
parameters = [
......@@ -305,7 +305,7 @@ def test15_evaluate(self):
def test16_collect_results(self):
# simply test that the collect_results script works
test_dir = tempfile.mkdtemp(prefix='frltest_')
test_dir = tempfile.mkdtemp(prefix='bobtest_')
from facereclib.script.collect_results import main
main(['--directory', test_dir, '--sort', '--sort-key', 'dir', '--criterion', 'FAR', '--self-test'])
os.rmdir(test_dir)
......@@ -313,7 +313,7 @@ def test16_collect_results(self):
def test21_parameter_script(self):
self.grid_available()
test_dir = tempfile.mkdtemp(prefix='frltest_')
test_dir = tempfile.mkdtemp(prefix='bobtest_')
# tests that the parameter_test.py script works properly
# first test without grid option
......
......@@ -27,19 +27,23 @@ class FileSelector:
"""Initialize the file selector object with the current configuration."""
self.database = database
self.original_directory = database.original_directory
self.preprocessed_directory = preprocessed_directory
self.extractor_file = extractor_file
self.extracted_directory = extracted_directory
self.projector_file = projector_file
self.projected_directory = projected_directory
self.enroller_file = enroller_file
self.model_directories = model_directories
self.score_directories = score_directories
self.zt_score_directories = zt_score_directories
self.default_extension = default_extension
self.compressed_extension = compressed_extension
self.directories = {
'original' : database.original_directory,
'preprocessed' : preprocessed_directory,
'extracted' : extracted_directory,
'projected' : projected_directory
}
def uses_probe_file_sets(self):
"""Returns true if the given protocol enables several probe files for scoring."""
......@@ -47,13 +51,9 @@ class FileSelector:
def get_paths(self, files, directory_type = None):
"""Returns the list of file names for the given list of File objects."""
if directory_type == 'preprocessed':
directory = self.preprocessed_directory
elif directory_type == 'extracted':
directory = self.extracted_directory
elif directory_type == 'projected':
directory = self.projected_directory
else:
try:
directory = self.directories[directory_type]
except KeyError:
raise ValueError("The given directory type '%s' is not supported." % directory_type)
return self.database.file_names(files, directory, self.default_extension)
......
......@@ -53,7 +53,7 @@ def project(algorithm, extractor, groups = None, indices = None, force=False):
else:
index_range = range(len(feature_files))
logger.info("- Projection: projecting %d features from directory '%s' to directory '%s'", len(index_range), fs.extracted_directory, fs.projected_directory)
logger.info("- Projection: projecting %d features from directory '%s' to directory '%s'", len(index_range), fs.directories['extracted'], fs.directories['projected'])
# extract the features
for i in index_range:
feature_file = str(feature_files[i])
......
......@@ -45,7 +45,7 @@ def extract(extractor, preprocessor, groups=None, indices = None, force=False):
else:
index_range = range(len(data_files))
logger.info("- Extraction: extracting %d features from directory '%s' to directory '%s'", len(index_range), fs.preprocessed_directory, fs.extracted_directory)
logger.info("- Extraction: extracting %d features from directory '%s' to directory '%s'", len(index_range), fs.directories['preprocessed'], fs.directories['extracted'])
for i in index_range:
data_file = str(data_files[i])
feature_file = str(feature_files[i])
......
......@@ -23,8 +23,7 @@ def preprocess(preprocessor, groups=None, indices=None, force=False):
else:
index_range = range(len(data_files))
bob.io.base.create_directories_safe(fs.preprocessed_directory)
logger.info("- Preprocessing: processing %d data files from directory '%s' to directory '%s'", len(index_range), fs.original_directory, fs.preprocessed_directory)
logger.info("- Preprocessing: processing %d data files from directory '%s' to directory '%s'", len(index_range), fs.directories['original'], fs.directories['preprocessed'])
# read annotation files
annotation_list = fs.annotation_list(groups=groups)
......
......@@ -45,7 +45,7 @@ def open_compressed(filename, open_flag = 'r', compression_type='bz2'):
In any case, the opened HDF5File is returned, which needs to be closed using the close_compressed() function.
"""
# create temporary HDF5 file name
hdf5_file_name = tempfile.mkstemp('.hdf5', 'frl_')[1]
hdf5_file_name = tempfile.mkstemp('.hdf5', 'bob_')[1]
if open_flag == 'r':
# extract the HDF5 file from the given file name into a temporary file name
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment