Skip to content
Snippets Groups Projects
Commit 16ea9865 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Created optimization that allows you to do fast scoring if such option is...

Created optimization that allows you to do fast scoring if such option is available in the BioAlgorithm AND the database allows you to do it. I still think that the Vanilla Pipeline should be in a class. This will come in a next MR
parent 6ae969d4
No related branches found
No related tags found
1 merge request!180[dask] Preparing bob.bio.base for dask pipelines
Pipeline #38588 failed
...@@ -12,10 +12,6 @@ class BioAlgorithm(metaclass=ABCMeta): ...@@ -12,10 +12,6 @@ class BioAlgorithm(metaclass=ABCMeta):
Parameters Parameters
---------- ----------
allow_score_multiple_references: bool
If true will call `self.score_multiple_biometric_references`, at scoring time, to compute scores in one shot with multiple probes.
This optiization is useful when all probes needs to be compared with all biometric references AND
your scoring function allows this broadcast computation.
""" """
...@@ -65,7 +61,7 @@ class BioAlgorithm(metaclass=ABCMeta): ...@@ -65,7 +61,7 @@ class BioAlgorithm(metaclass=ABCMeta):
""" """
pass pass
def score_samples(self, probe_features, biometric_references): def score_samples(self, probe_features, biometric_references, allow_scoring_with_all_biometric_references=False):
"""Scores a new sample against multiple (potential) references """Scores a new sample against multiple (potential) references
Parameters Parameters
...@@ -80,6 +76,12 @@ class BioAlgorithm(metaclass=ABCMeta): ...@@ -80,6 +76,12 @@ class BioAlgorithm(metaclass=ABCMeta):
scoring the input probes, must have an ``id`` attribute that scoring the input probes, must have an ``id`` attribute that
will be used to cross-reference which probes need to be scored. will be used to cross-reference which probes need to be scored.
allow_scoring_with_all_biometric_references: bool
If true will call `self.score_multiple_biometric_references`, at scoring time, to compute scores in one shot with multiple probes.
This optiization is useful when all probes needs to be compared with all biometric references AND
your scoring function allows this broadcast computation.
Returns Returns
------- -------
...@@ -92,10 +94,10 @@ class BioAlgorithm(metaclass=ABCMeta): ...@@ -92,10 +94,10 @@ class BioAlgorithm(metaclass=ABCMeta):
retval = [] retval = []
for p in probe_features: for p in probe_features:
retval.append(self._score_sample_set(p, biometric_references)) retval.append(self._score_sample_set(p, biometric_references, allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references))
return retval return retval
def _score_sample_set(self, sampleset, biometric_references): def _score_sample_set(self, sampleset, biometric_references, allow_scoring_with_all_biometric_references):
"""Given a sampleset for probing, compute the scores and retures a sample set with the scores """Given a sampleset for probing, compute the scores and retures a sample set with the scores
""" """
...@@ -117,7 +119,7 @@ class BioAlgorithm(metaclass=ABCMeta): ...@@ -117,7 +119,7 @@ class BioAlgorithm(metaclass=ABCMeta):
# Creating one sample per comparison # Creating one sample per comparison
subprobe_scores = [] subprobe_scores = []
if self.allow_score_multiple_references: if allow_scoring_with_all_biometric_references:
# Multiple scoring # Multiple scoring
if self.stacked_biometric_references is None: if self.stacked_biometric_references is None:
self.stacked_biometric_references = [ self.stacked_biometric_references = [
......
...@@ -202,7 +202,12 @@ def _get_pickable_method(method): ...@@ -202,7 +202,12 @@ def _get_pickable_method(method):
class Preprocessor(CheckpointMixin, SampleMixin, _Preprocessor): class Preprocessor(CheckpointMixin, SampleMixin, _Preprocessor):
def __init__(self, callable, transform_extra_arguments=(("annotations", "annotations"),), **kwargs): def __init__(
self,
callable,
transform_extra_arguments=(("annotations", "annotations"),),
**kwargs,
):
instance = callable() instance = callable()
super().__init__( super().__init__(
callable=callable, callable=callable,
...@@ -240,7 +245,10 @@ class _Extractor(_NonPickableWrapper, TransformerMixin, BaseEstimator): ...@@ -240,7 +245,10 @@ class _Extractor(_NonPickableWrapper, TransformerMixin, BaseEstimator):
return self return self
def _more_tags(self): def _more_tags(self):
return {"requires_fit": self.instance.requires_training} return {
"requires_fit": self.instance.requires_training,
"stateless": not self.instance.requires_training,
}
class Extractor(CheckpointMixin, SampleMixin, _Extractor): class Extractor(CheckpointMixin, SampleMixin, _Extractor):
...@@ -381,7 +389,12 @@ class AlgorithmAsBioAlg(_NonPickableWrapper, BioAlgorithm): ...@@ -381,7 +389,12 @@ class AlgorithmAsBioAlg(_NonPickableWrapper, BioAlgorithm):
# Enroll # Enroll
return self.enroll(sampleset) return self.enroll(sampleset)
def _score_sample_set(self, sampleset, biometric_references): def _score_sample_set(
self,
sampleset,
biometric_references,
allow_scoring_with_all_biometric_references=False,
):
"""Given a sampleset for probing, compute the scores and retures a sample set with the scores """Given a sampleset for probing, compute the scores and retures a sample set with the scores
""" """
...@@ -400,7 +413,7 @@ class AlgorithmAsBioAlg(_NonPickableWrapper, BioAlgorithm): ...@@ -400,7 +413,7 @@ class AlgorithmAsBioAlg(_NonPickableWrapper, BioAlgorithm):
# Creating one sample per comparison # Creating one sample per comparison
subprobe_scores = [] subprobe_scores = []
if self.allow_score_multiple_references: if allow_scoring_with_all_biometric_references:
if self.stacked_biometric_references is None: if self.stacked_biometric_references is None:
self.stacked_biometric_references = [ self.stacked_biometric_references = [
ref.data for ref in biometric_references ref.data for ref in biometric_references
......
...@@ -65,12 +65,21 @@ class BioAlgCheckpointMixin(CheckpointMixin): ...@@ -65,12 +65,21 @@ class BioAlgCheckpointMixin(CheckpointMixin):
return delayed_enrolled_sample return delayed_enrolled_sample
def _score_sample_set(self, sampleset, biometric_references): def _score_sample_set(
self,
sampleset,
biometric_references,
allow_scoring_with_all_biometric_references=False
):
"""Given a sampleset for probing, compute the scores and retures a sample set with the scores """Given a sampleset for probing, compute the scores and retures a sample set with the scores
""" """
# Computing score # Computing score
scored_sample_set = super()._score_sample_set(sampleset, biometric_references) scored_sample_set = super()._score_sample_set(
sampleset,
biometric_references,
allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
)
for s in scored_sample_set: for s in scored_sample_set:
# Checkpointing score # Checkpointing score
path = os.path.join(self.score_dir, str(s.path) + ".txt") path = os.path.join(self.score_dir, str(s.path) + ".txt")
...@@ -89,7 +98,12 @@ class BioAlgDaskMixin: ...@@ -89,7 +98,12 @@ class BioAlgDaskMixin:
) )
return biometric_references return biometric_references
def score_samples(self, probe_features, biometric_references): def score_samples(
self,
probe_features,
biometric_references,
allow_scoring_with_all_biometric_references=False,
):
# TODO: Here, we are sending all computed biometric references to all # TODO: Here, we are sending all computed biometric references to all
# probes. It would be more efficient if only the models related to each # probes. It would be more efficient if only the models related to each
...@@ -100,5 +114,9 @@ class BioAlgDaskMixin: ...@@ -100,5 +114,9 @@ class BioAlgDaskMixin:
all_references = dask.delayed(list)(biometric_references) all_references = dask.delayed(list)(biometric_references)
scores = probe_features.map_partitions(super().score_samples, all_references) scores = probe_features.map_partitions(
super().score_samples,
all_references,
allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
)
return scores return scores
...@@ -168,12 +168,19 @@ def vanilla_biometrics( ...@@ -168,12 +168,19 @@ def vanilla_biometrics(
logger.info(f"Running vanilla biometrics for group {group}") logger.info(f"Running vanilla biometrics for group {group}")
allow_scoring_with_all_biometric_references = (
database.allow_scoring_with_all_biometric_references
if hasattr(database, "allow_scoring_with_all_biometric_references")
else False
)
result = biometric_pipeline( result = biometric_pipeline(
database.background_model_samples(), database.background_model_samples(),
biometric_references, biometric_references,
database.probes(group=group), database.probes(group=group),
transformer, transformer,
algorithm, algorithm,
allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references
) )
if isinstance(result, dask.bag.core.Bag): if isinstance(result, dask.bag.core.Bag):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment