Commit 1f408e55 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Created the method is_argument_available that checks if an argument is...

Created the method is_argument_available that checks if an argument is available for certain method. provided both python2 and python3 implementations

Used the new method is_argument_available

Used the new method is_argument_available

Improved the test cases
parent a8746ea4
Pipeline #17206 passed with stage
in 22 minutes and 45 seconds
......@@ -2,6 +2,7 @@ import scipy.spatial
import bob.io.base
import numpy
from bob.bio.base.algorithm import Algorithm
from bob.bio.base.database import BioFile
_data = [5., 6., 7., 8., 9.]
......@@ -63,7 +64,7 @@ class DummyAlgorithmMetadata (DummyAlgorithm):
def train_projector(self, train_files, projector_file, metadata=None):
"""Does nothing, simply converts the data type of the data, ignoring any annotation."""
assert metadata is not None
assert isinstance(metadata, list)
return super(DummyAlgorithmMetadata, self).train_projector(train_files, projector_file)
def enroll(self, enroll_features, metadata=None):
......@@ -74,11 +75,11 @@ class DummyAlgorithmMetadata (DummyAlgorithm):
def score(self, model, probe, metadata=None):
"""Returns the Euclidean distance between model and probe"""
assert metadata is not None
assert isinstance(metadata, BioFile)
return super(DummyAlgorithmMetadata, self).score(model, probe)
def project(self, feature, metadata=None):
assert metadata is not None
assert isinstance(metadata, BioFile)
return super(DummyAlgorithmMetadata, self).project(feature)
algorithm_metadata = DummyAlgorithmMetadata()
import numpy
import bob.bio.base
from bob.bio.base.database import BioFile
from bob.bio.base.extractor import Extractor
_data = [0., 1., 2., 3., 4.]
......@@ -31,8 +31,7 @@ class DummyExtractorMetadata (DummyExtractor):
def __call__(self, data, metadata=None):
"""Does nothing, simply converts the data type of the data, ignoring any annotation."""
assert metadata is not None
assert self.model
return data.astype(numpy.float).flatten()
assert isinstance(metadata, BioFile)
return super(DummyExtractorMetadata, self).__call__(data)
extractor_metadata = DummyExtractorMetadata()
from bob.bio.base.preprocessor import Preprocessor
from bob.bio.base.database import BioFile
import numpy
numpy.random.seed(10)
......@@ -23,7 +24,7 @@ class DummyPreprocessorMetadata (DummyPreprocessor):
def __call__(self, data, annotation, metadata=None):
"""Does nothing, simply converts the data type of the data, ignoring any annotation."""
assert metadata is not None
assert isinstance(metadata, BioFile)
return super(DummyPreprocessorMetadata, self).__call__(data, annotation)
preprocessor_metadata = DummyPreprocessorMetadata()
......@@ -55,7 +55,7 @@ def train_projector(algorithm, extractor, allow_missing_files = False, force = F
logger.info("- Projection: training projector '%s' using %d training files: ", fs.projector_file, len(train_files))
# perform training
if "metadata" in inspect.getargspec(algorithm.train_projector).args:
if utils.is_argument_available("metadata", algorithm.train_projector):
metadata = fs.database.training_files('train_projector', algorithm.split_training_features_by_client)
algorithm.train_projector(train_features, fs.projector_file, metadata=metadata)
else:
......@@ -138,7 +138,7 @@ def project(algorithm, extractor, groups = None, indices = None, allow_missing_f
# project feature
if "metadata" in inspect.getargspec(algorithm.project).args:
projected = algorithm.project(feature, metadata=metadata)
projected = algorithm.project(feature, metadata=metadata[i])
else:
projected = algorithm.project(feature)
......@@ -257,6 +257,9 @@ def enroll(algorithm, extractor, compute_zt_norm, indices = None, groups = ['dev
# which tool to use to read the features...
reader = algorithm if algorithm.use_projected_features_for_enrollment else extractor
# Checking if we need to ship the metadata to the method enroll
has_metadata = utils.is_argument_available("metadata", algorithm.enroll)
# Create Models
if 'N' in types:
for group in groups:
......@@ -290,7 +293,7 @@ def enroll(algorithm, extractor, compute_zt_norm, indices = None, groups = ['dev
# load all files into memory
enroll_features = [reader.read_feature(enroll_file) for enroll_file in enroll_files]
if "metadata" in inspect.getargspec(algorithm.enroll).args:
if has_metadata:
metadata = fs.database.enroll_files(group=group, model_id=model_id)
model = algorithm.enroll(enroll_features, metadata=metadata)
else:
......@@ -341,7 +344,7 @@ def enroll(algorithm, extractor, compute_zt_norm, indices = None, groups = ['dev
# load all files into memory
t_enroll_features = [reader.read_feature(t_enroll_file) for t_enroll_file in t_enroll_files]
if "metadata" in inspect.getargspec(algorithm.enroll).args:
if has_metadata:
metadata = fs.database.enroll_files(group=group, model_id=t_model_id)
t_model = algorithm.enroll(t_enroll_features, metadata=metadata)
else:
......
......@@ -91,7 +91,11 @@ def extract(extractor, preprocessor, groups=None, indices = None, allow_missing_
extractor.load(fs.extractor_file)
data_files = fs.preprocessed_data_list(groups=groups)
feature_files = fs.feature_list(groups=groups)
metadata = fs.original_data_list(groups=groups)
if utils.is_argument_available("metadata", extractor.__call__):
metadata = fs.original_data_list(groups=groups)
else:
metadata = None
# select a subset of indices to iterate
if indices is not None:
......@@ -120,11 +124,12 @@ def extract(extractor, preprocessor, groups=None, indices = None, allow_missing_
bob.io.base.create_directories_safe(os.path.dirname(feature_file))
# load data
data = preprocessor.read_data(data_file)
# extract feature
if "metadata" in inspect.getargspec(extractor.__call__).args:
feature = extractor(data, metadata=metadata[i])
else:
if metadata is None:
feature = extractor(data)
else:
feature = extractor(data, metadata=metadata[i])
if feature is None:
if allow_missing_files:
......
......@@ -46,7 +46,11 @@ def preprocess(preprocessor, groups = None, indices = None, allow_missing_files
data_files = fs.original_data_list(groups=groups)
original_directory, original_extension = fs.original_directory_and_extension()
preprocessed_data_files = fs.preprocessed_data_list(groups=groups)
metadata = fs.original_data_list(groups=groups)
if utils.is_argument_available("metadata", preprocessor.__call__):
metadata = fs.original_data_list(groups=groups)
else:
metadata = None
# select a subset of keys to iterate
if indices is not None:
......@@ -80,10 +84,10 @@ def preprocess(preprocessor, groups = None, indices = None, allow_missing_files
annotations = fs.get_annotations(annotation_list[i])
# call the preprocessor
if "metadata" in inspect.getargspec(preprocessor.__call__).args:
preprocessed_data = preprocessor(data, annotations, metadata=metadata[i])
else:
if metadata is None:
preprocessed_data = preprocessor(data, annotations)
else:
preprocessed_data = preprocessor(data, annotations, metadata=metadata[i])
if preprocessed_data is None:
if allow_missing_files:
......
......@@ -27,6 +27,9 @@ def _scores(algorithm, reader, model, probe_objects, allow_missing_files):
# if we have no model, all scores are undefined
return scores
# Checking if we need to ship the metadata in the scoring method
has_metadata = utils.is_argument_available("metadata", algorithm.score)
# Loops over the probe sets
for i, probe_element, probe_metadata in zip(range(len(probes)), probes, probe_objects):
if fs.uses_probe_file_sets():
......@@ -47,8 +50,9 @@ def _scores(algorithm, reader, model, probe_objects, allow_missing_files):
continue
# read probe
probe = reader.read_feature(probe_element)
# compute score
if "metadata" in inspect.getargspec(algorithm.score).args:
if has_metadata:
scores[0, i] = algorithm.score(model, probe, metadata=probe_metadata)
else:
scores[0, i] = algorithm.score(model, probe)
......
......@@ -7,7 +7,8 @@ from .resources import *
from .io import *
from .singleton import *
from . import processors
import six
import inspect
import numpy
def score_fusion_strategy(strategy_name = 'average'):
......@@ -53,6 +54,27 @@ def selected_elements(list_of_elements, desired_number_of_elements = None):
# sub-select
return [list_of_elements[i] for i in selected_indices(total_number_of_elements, desired_number_of_elements)]
def pretty_print(obj, kwargs):
"""Returns a pretty-print of the parameters to the constructor of a class, which should be able to copy-paste on the command line to create the object (with few exceptions)."""
return "%s(%s)" % (str(obj.__class__), ", ".join(["%s='%s'" % (key,value) if isinstance(value, str) else "%s=%s" % (key, value) for key,value in kwargs.items() if value is not None]))
def is_argument_available(argument, method):
"""
Check if an argument (or keyword argument) is available in a method
Attributes
----------
argument: str
The name of the argument (or keyword argument).
method:
Pointer to the method
"""
if six.PY2:
return argument in inspect.getargspec(method).args
else:
return argument in inspect.signature(method).parameters.keys()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment