diff --git a/bob/bio/base/test/dummy/algorithm.py b/bob/bio/base/test/dummy/algorithm.py index ab427baeaa2292b6ec43a9fb52103d142866f286..c877cd9a0b6465b3dd893b0a3061b641f667358a 100644 --- a/bob/bio/base/test/dummy/algorithm.py +++ b/bob/bio/base/test/dummy/algorithm.py @@ -2,6 +2,7 @@ import scipy.spatial import bob.io.base import numpy from bob.bio.base.algorithm import Algorithm +from bob.bio.base.database import BioFile _data = [5., 6., 7., 8., 9.] @@ -63,7 +64,7 @@ class DummyAlgorithmMetadata (DummyAlgorithm): def train_projector(self, train_files, projector_file, metadata=None): """Does nothing, simply converts the data type of the data, ignoring any annotation.""" - assert metadata is not None + assert isinstance(metadata, list) return super(DummyAlgorithmMetadata, self).train_projector(train_files, projector_file) def enroll(self, enroll_features, metadata=None): @@ -74,11 +75,11 @@ class DummyAlgorithmMetadata (DummyAlgorithm): def score(self, model, probe, metadata=None): """Returns the Euclidean distance between model and probe""" - assert metadata is not None + assert isinstance(metadata, BioFile) return super(DummyAlgorithmMetadata, self).score(model, probe) def project(self, feature, metadata=None): - assert metadata is not None + assert isinstance(metadata, BioFile) return super(DummyAlgorithmMetadata, self).project(feature) algorithm_metadata = DummyAlgorithmMetadata() diff --git a/bob/bio/base/test/dummy/extractor.py b/bob/bio/base/test/dummy/extractor.py index dbdf0946a01377dfa9db3ff483391710e20d9990..2e53464569b5579765c06554797d3e28a2963ee2 100644 --- a/bob/bio/base/test/dummy/extractor.py +++ b/bob/bio/base/test/dummy/extractor.py @@ -1,6 +1,6 @@ import numpy import bob.bio.base - +from bob.bio.base.database import BioFile from bob.bio.base.extractor import Extractor _data = [0., 1., 2., 3., 4.] @@ -31,8 +31,7 @@ class DummyExtractorMetadata (DummyExtractor): def __call__(self, data, metadata=None): """Does nothing, simply converts the data type of the data, ignoring any annotation.""" - assert metadata is not None - assert self.model - return data.astype(numpy.float).flatten() + assert isinstance(metadata, BioFile) + return super(DummyExtractorMetadata, self).__call__(data) extractor_metadata = DummyExtractorMetadata() diff --git a/bob/bio/base/test/dummy/preprocessor.py b/bob/bio/base/test/dummy/preprocessor.py index 506f0ef881df79e09c12c028174222bcb6f189b9..1c14bb1b1c97fbe72d135947d79da7d697193bef 100644 --- a/bob/bio/base/test/dummy/preprocessor.py +++ b/bob/bio/base/test/dummy/preprocessor.py @@ -1,4 +1,5 @@ from bob.bio.base.preprocessor import Preprocessor +from bob.bio.base.database import BioFile import numpy numpy.random.seed(10) @@ -23,7 +24,7 @@ class DummyPreprocessorMetadata (DummyPreprocessor): def __call__(self, data, annotation, metadata=None): """Does nothing, simply converts the data type of the data, ignoring any annotation.""" - assert metadata is not None + assert isinstance(metadata, BioFile) return super(DummyPreprocessorMetadata, self).__call__(data, annotation) preprocessor_metadata = DummyPreprocessorMetadata() diff --git a/bob/bio/base/tools/algorithm.py b/bob/bio/base/tools/algorithm.py index ff2d69ff3d8553751b9d3bc9b169fc093f3e9c51..2bbaa558352d0d5ba22f6bf2ce1743ebe19524f5 100644 --- a/bob/bio/base/tools/algorithm.py +++ b/bob/bio/base/tools/algorithm.py @@ -55,7 +55,7 @@ def train_projector(algorithm, extractor, allow_missing_files = False, force = F logger.info("- Projection: training projector '%s' using %d training files: ", fs.projector_file, len(train_files)) # perform training - if "metadata" in inspect.getargspec(algorithm.train_projector).args: + if utils.is_argument_available("metadata", algorithm.train_projector): metadata = fs.database.training_files('train_projector', algorithm.split_training_features_by_client) algorithm.train_projector(train_features, fs.projector_file, metadata=metadata) else: @@ -138,7 +138,7 @@ def project(algorithm, extractor, groups = None, indices = None, allow_missing_f # project feature if "metadata" in inspect.getargspec(algorithm.project).args: - projected = algorithm.project(feature, metadata=metadata) + projected = algorithm.project(feature, metadata=metadata[i]) else: projected = algorithm.project(feature) @@ -257,6 +257,9 @@ def enroll(algorithm, extractor, compute_zt_norm, indices = None, groups = ['dev # which tool to use to read the features... reader = algorithm if algorithm.use_projected_features_for_enrollment else extractor + # Checking if we need to ship the metadata to the method enroll + has_metadata = utils.is_argument_available("metadata", algorithm.enroll) + # Create Models if 'N' in types: for group in groups: @@ -290,7 +293,7 @@ def enroll(algorithm, extractor, compute_zt_norm, indices = None, groups = ['dev # load all files into memory enroll_features = [reader.read_feature(enroll_file) for enroll_file in enroll_files] - if "metadata" in inspect.getargspec(algorithm.enroll).args: + if has_metadata: metadata = fs.database.enroll_files(group=group, model_id=model_id) model = algorithm.enroll(enroll_features, metadata=metadata) else: @@ -341,7 +344,7 @@ def enroll(algorithm, extractor, compute_zt_norm, indices = None, groups = ['dev # load all files into memory t_enroll_features = [reader.read_feature(t_enroll_file) for t_enroll_file in t_enroll_files] - if "metadata" in inspect.getargspec(algorithm.enroll).args: + if has_metadata: metadata = fs.database.enroll_files(group=group, model_id=t_model_id) t_model = algorithm.enroll(t_enroll_features, metadata=metadata) else: diff --git a/bob/bio/base/tools/extractor.py b/bob/bio/base/tools/extractor.py index 84705eac55b0cb3f607359ddefbde0c2f40fab4d..7f822c4918cf862360606e05cb1d63ad64f36423 100644 --- a/bob/bio/base/tools/extractor.py +++ b/bob/bio/base/tools/extractor.py @@ -91,7 +91,11 @@ def extract(extractor, preprocessor, groups=None, indices = None, allow_missing_ extractor.load(fs.extractor_file) data_files = fs.preprocessed_data_list(groups=groups) feature_files = fs.feature_list(groups=groups) - metadata = fs.original_data_list(groups=groups) + + if utils.is_argument_available("metadata", extractor.__call__): + metadata = fs.original_data_list(groups=groups) + else: + metadata = None # select a subset of indices to iterate if indices is not None: @@ -120,11 +124,12 @@ def extract(extractor, preprocessor, groups=None, indices = None, allow_missing_ bob.io.base.create_directories_safe(os.path.dirname(feature_file)) # load data data = preprocessor.read_data(data_file) + # extract feature - if "metadata" in inspect.getargspec(extractor.__call__).args: - feature = extractor(data, metadata=metadata[i]) - else: + if metadata is None: feature = extractor(data) + else: + feature = extractor(data, metadata=metadata[i]) if feature is None: if allow_missing_files: diff --git a/bob/bio/base/tools/preprocessor.py b/bob/bio/base/tools/preprocessor.py index b4140dfac86ea14b8f88d51ee54220f307c40020..93be41169d8c5ef32c9514095d224db511dbae15 100644 --- a/bob/bio/base/tools/preprocessor.py +++ b/bob/bio/base/tools/preprocessor.py @@ -46,7 +46,11 @@ def preprocess(preprocessor, groups = None, indices = None, allow_missing_files data_files = fs.original_data_list(groups=groups) original_directory, original_extension = fs.original_directory_and_extension() preprocessed_data_files = fs.preprocessed_data_list(groups=groups) - metadata = fs.original_data_list(groups=groups) + + if utils.is_argument_available("metadata", preprocessor.__call__): + metadata = fs.original_data_list(groups=groups) + else: + metadata = None # select a subset of keys to iterate if indices is not None: @@ -80,10 +84,10 @@ def preprocess(preprocessor, groups = None, indices = None, allow_missing_files annotations = fs.get_annotations(annotation_list[i]) # call the preprocessor - if "metadata" in inspect.getargspec(preprocessor.__call__).args: - preprocessed_data = preprocessor(data, annotations, metadata=metadata[i]) - else: + if metadata is None: preprocessed_data = preprocessor(data, annotations) + else: + preprocessed_data = preprocessor(data, annotations, metadata=metadata[i]) if preprocessed_data is None: if allow_missing_files: diff --git a/bob/bio/base/tools/scoring.py b/bob/bio/base/tools/scoring.py index 58f5d81703498a492ec88e6663a4b0c19f95dce8..18aed256117d7c06dcb5dab84b50c25c5305ded0 100644 --- a/bob/bio/base/tools/scoring.py +++ b/bob/bio/base/tools/scoring.py @@ -27,6 +27,9 @@ def _scores(algorithm, reader, model, probe_objects, allow_missing_files): # if we have no model, all scores are undefined return scores + # Checking if we need to ship the metadata in the scoring method + has_metadata = utils.is_argument_available("metadata", algorithm.score) + # Loops over the probe sets for i, probe_element, probe_metadata in zip(range(len(probes)), probes, probe_objects): if fs.uses_probe_file_sets(): @@ -47,8 +50,9 @@ def _scores(algorithm, reader, model, probe_objects, allow_missing_files): continue # read probe probe = reader.read_feature(probe_element) + # compute score - if "metadata" in inspect.getargspec(algorithm.score).args: + if has_metadata: scores[0, i] = algorithm.score(model, probe, metadata=probe_metadata) else: scores[0, i] = algorithm.score(model, probe) diff --git a/bob/bio/base/utils/__init__.py b/bob/bio/base/utils/__init__.py index b65569cb5106e308123ab5c1a4022ee54a4fbc81..2cd7501a46599d3497387dc3c19ce2c32b3cf01b 100644 --- a/bob/bio/base/utils/__init__.py +++ b/bob/bio/base/utils/__init__.py @@ -7,7 +7,8 @@ from .resources import * from .io import * from .singleton import * from . import processors - +import six +import inspect import numpy def score_fusion_strategy(strategy_name = 'average'): @@ -53,6 +54,27 @@ def selected_elements(list_of_elements, desired_number_of_elements = None): # sub-select return [list_of_elements[i] for i in selected_indices(total_number_of_elements, desired_number_of_elements)] + def pretty_print(obj, kwargs): """Returns a pretty-print of the parameters to the constructor of a class, which should be able to copy-paste on the command line to create the object (with few exceptions).""" return "%s(%s)" % (str(obj.__class__), ", ".join(["%s='%s'" % (key,value) if isinstance(value, str) else "%s=%s" % (key, value) for key,value in kwargs.items() if value is not None])) + + +def is_argument_available(argument, method): + """ + Check if an argument (or keyword argument) is available in a method + + Attributes + ---------- + argument: str + The name of the argument (or keyword argument). + + method: + Pointer to the method + + """ + + if six.PY2: + return argument in inspect.getargspec(method).args + else: + return argument in inspect.signature(method).parameters.keys()