Commit f8df4c0c authored by Hatef OTROSHI's avatar Hatef OTROSHI
Browse files

add average_on_enroll argument to Distance

parent 1262a6c6
Pipeline #57784 passed with stage
in 16 minutes and 9 seconds
......@@ -10,11 +10,12 @@ import functools
class Distance(BioAlgorithm):
def __init__(
self, distance_function=scipy.spatial.distance.cosine, factor=-1, **kwargs
self, distance_function=scipy.spatial.distance.cosine, factor=-1, average_on_enroll=True, **kwargs
):
super().__init__(**kwargs)
self.distance_function = distance_function
self.factor = factor
self.average_on_enroll = average_on_enroll # if True average of features is calculated, if False average of scores is calculated
def _make_2d(self, X):
"""
......@@ -58,7 +59,7 @@ class Distance(BioAlgorithm):
# That dumps vectors in the format `Nx1xd`
assert enroll_features.ndim == 2
return np.mean(enroll_features, axis=0)
return np.mean(enroll_features, axis=0) if self.average_on_enroll else enroll_features
def score(self, biometric_reference, data):
"""score(model, probe) -> float
......@@ -91,7 +92,9 @@ class Distance(BioAlgorithm):
assert data.ndim == 2
# return the negative distance (as a similarity measure)
return self.factor * self.distance_function(biometric_reference, data)
scores = self.factor * self.distance_function(biometric_reference, data)
return scores if self.average_on_enroll else np.mean(scores)
def score_multiple_biometric_references(self, biometric_references, data):
......@@ -106,39 +109,4 @@ class Distance(BioAlgorithm):
references_stacked = np.vstack(biometric_references)
scores = self.factor * cdist(references_stacked, data, self.distance_function)
return scores
class Distance2(Distance):
def __init__(
self, distance_function=scipy.spatial.distance.cosine, factor=-1, **kwargs
):
super().__init__(distance_function=distance_function, factor=factor, **kwargs)
def enroll(self, enroll_features):
"""enroll(enroll_features) -> model
Enrolls the model by storing all given input vectors.
Parameters
----------
``enroll_features`` : [:py:class:`numpy.ndarray`]
The list of projected features to enroll the model from.
Returns
-------
``model`` : 2D :py:class:`numpy.ndarray`
The enrolled model.
"""
enroll_features = check_array(enroll_features, allow_nd=True, ensure_2d=True)
enroll_features = self._make_2d(enroll_features)
# This avoids some possible mistakes in the feature extraction
# That dumps vectors in the format `Nx1xd`
assert enroll_features.ndim == 2
return enroll_features
\ No newline at end of file
return scores
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment