Skip to content
Snippets Groups Projects
Commit 16ea9865 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Created optimization that allows you to do fast scoring if such option is...

Created an optimization that allows fast scoring if such an option is available in the BioAlgorithm AND the database allows it. I still think that the Vanilla Pipeline should be in a class; this will come in a future MR.
parent 6ae969d4
Branches
Tags
1 merge request: !180 [dask] Preparing bob.bio.base for dask pipelines
Pipeline #38588 failed
......@@ -12,10 +12,6 @@ class BioAlgorithm(metaclass=ABCMeta):
Parameters
----------
allow_score_multiple_references: bool
If true, `self.score_multiple_biometric_references` is called at scoring time to compute scores in one shot with multiple probes.
This optimization is useful when all probes need to be compared with all biometric references AND
your scoring function allows this broadcast computation.
"""
......@@ -65,7 +61,7 @@ class BioAlgorithm(metaclass=ABCMeta):
"""
pass
def score_samples(self, probe_features, biometric_references):
def score_samples(self, probe_features, biometric_references, allow_scoring_with_all_biometric_references=False):
"""Scores a new sample against multiple (potential) references
Parameters
......@@ -80,6 +76,12 @@ class BioAlgorithm(metaclass=ABCMeta):
scoring the input probes, must have an ``id`` attribute that
will be used to cross-reference which probes need to be scored.
allow_scoring_with_all_biometric_references: bool
If true, `self.score_multiple_biometric_references` is called at scoring time to compute scores in one shot with multiple probes.
This optimization is useful when all probes need to be compared with all biometric references AND
your scoring function allows this broadcast computation.
Returns
-------
......@@ -92,10 +94,10 @@ class BioAlgorithm(metaclass=ABCMeta):
retval = []
for p in probe_features:
retval.append(self._score_sample_set(p, biometric_references))
retval.append(self._score_sample_set(p, biometric_references, allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references))
return retval
def _score_sample_set(self, sampleset, biometric_references):
def _score_sample_set(self, sampleset, biometric_references, allow_scoring_with_all_biometric_references):
"""Given a sampleset for probing, compute the scores and return a sample set with the scores
"""
......@@ -117,7 +119,7 @@ class BioAlgorithm(metaclass=ABCMeta):
# Creating one sample per comparison
subprobe_scores = []
if self.allow_score_multiple_references:
if allow_scoring_with_all_biometric_references:
# Multiple scoring
if self.stacked_biometric_references is None:
self.stacked_biometric_references = [
......
......@@ -202,7 +202,12 @@ def _get_pickable_method(method):
class Preprocessor(CheckpointMixin, SampleMixin, _Preprocessor):
def __init__(self, callable, transform_extra_arguments=(("annotations", "annotations"),), **kwargs):
def __init__(
self,
callable,
transform_extra_arguments=(("annotations", "annotations"),),
**kwargs,
):
instance = callable()
super().__init__(
callable=callable,
......@@ -240,7 +245,10 @@ class _Extractor(_NonPickableWrapper, TransformerMixin, BaseEstimator):
return self
def _more_tags(self):
return {"requires_fit": self.instance.requires_training}
return {
"requires_fit": self.instance.requires_training,
"stateless": not self.instance.requires_training,
}
class Extractor(CheckpointMixin, SampleMixin, _Extractor):
......@@ -381,7 +389,12 @@ class AlgorithmAsBioAlg(_NonPickableWrapper, BioAlgorithm):
# Enroll
return self.enroll(sampleset)
def _score_sample_set(self, sampleset, biometric_references):
def _score_sample_set(
self,
sampleset,
biometric_references,
allow_scoring_with_all_biometric_references=False,
):
"""Given a sampleset for probing, compute the scores and return a sample set with the scores
"""
......@@ -400,7 +413,7 @@ class AlgorithmAsBioAlg(_NonPickableWrapper, BioAlgorithm):
# Creating one sample per comparison
subprobe_scores = []
if self.allow_score_multiple_references:
if allow_scoring_with_all_biometric_references:
if self.stacked_biometric_references is None:
self.stacked_biometric_references = [
ref.data for ref in biometric_references
......
......@@ -65,12 +65,21 @@ class BioAlgCheckpointMixin(CheckpointMixin):
return delayed_enrolled_sample
def _score_sample_set(self, sampleset, biometric_references):
def _score_sample_set(
self,
sampleset,
biometric_references,
allow_scoring_with_all_biometric_references=False
):
"""Given a sampleset for probing, compute the scores and return a sample set with the scores
"""
# Computing score
scored_sample_set = super()._score_sample_set(sampleset, biometric_references)
scored_sample_set = super()._score_sample_set(
sampleset,
biometric_references,
allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
)
for s in scored_sample_set:
# Checkpointing score
path = os.path.join(self.score_dir, str(s.path) + ".txt")
......@@ -89,7 +98,12 @@ class BioAlgDaskMixin:
)
return biometric_references
def score_samples(self, probe_features, biometric_references):
def score_samples(
self,
probe_features,
biometric_references,
allow_scoring_with_all_biometric_references=False,
):
# TODO: Here, we are sending all computed biometric references to all
# probes. It would be more efficient if only the models related to each
......@@ -100,5 +114,9 @@ class BioAlgDaskMixin:
all_references = dask.delayed(list)(biometric_references)
scores = probe_features.map_partitions(super().score_samples, all_references)
scores = probe_features.map_partitions(
super().score_samples,
all_references,
allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
)
return scores
......@@ -168,12 +168,19 @@ def vanilla_biometrics(
logger.info(f"Running vanilla biometrics for group {group}")
allow_scoring_with_all_biometric_references = (
database.allow_scoring_with_all_biometric_references
if hasattr(database, "allow_scoring_with_all_biometric_references")
else False
)
result = biometric_pipeline(
database.background_model_samples(),
biometric_references,
database.probes(group=group),
transformer,
algorithm,
allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references
)
if isinstance(result, dask.bag.core.Bag):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment