diff --git a/bob/bio/base/config/examples/pca_atnt.py b/bob/bio/base/config/examples/pca_atnt.py index 7ae7531361be68462a563f1ae2b35fae126e1e67..3cd159b1d621d98a71e25e8b76d2a03d9ef05330 100644 --- a/bob/bio/base/config/examples/pca_atnt.py +++ b/bob/bio/base/config/examples/pca_atnt.py @@ -18,7 +18,7 @@ transformer = make_pipeline( features_dir=os.path.join(base_dir, "pca_features"), model_path=os.path.join(base_dir, "pca.pkl") ), ) -algorithm = CheckpointDistance(features_dir=base_dir) +algorithm = CheckpointDistance(features_dir=base_dir, allow_score_multiple_references=True) # # comment out the code below to disable dask from bob.pipelines.mixins import estimator_dask_it, mix_me_up diff --git a/bob/bio/base/pipelines/vanilla_biometrics/abstract_classes.py b/bob/bio/base/pipelines/vanilla_biometrics/abstract_classes.py index 97fa6e70fe928af2709c369c68af2b5a565f37ad..e040c0f193506ceff03842bfa490531748b1fda6 100644 --- a/bob/bio/base/pipelines/vanilla_biometrics/abstract_classes.py +++ b/bob/bio/base/pipelines/vanilla_biometrics/abstract_classes.py @@ -9,8 +9,20 @@ class BioAlgorithm(metaclass=ABCMeta): biometric model enrollement, via ``enroll()`` and scoring, with ``score()``. + Parameters + ---------- + + allow_score_multiple_references: bool + If true will call `self.score_multiple_biometric_references`, at scoring time, to compute scores in one shot with multiple probes. + This optiization is useful when all probes needs to be compared with all biometric references AND + your scoring function allows this broadcast computation. + """ + def __init__(self, allow_score_multiple_references=False): + self.allow_score_multiple_references = allow_score_multiple_references + self.stacked_biometric_references = None + def enroll_samples(self, biometric_references): """This method should implement the sub-pipeline 1 of the Vanilla Biometrics Pipeline :ref:`_vanilla-pipeline-1`. @@ -96,17 +108,35 @@ class BioAlgorithm(metaclass=ABCMeta): # a sampleset either after or before scoring. # To be honest, this should be the default behaviour retval = [] + + def _write_sample(ref, probe, score): + data = make_four_colums_score(ref.subject, probe.subject, probe.path, score) + return Sample(data, parent=ref) + for subprobe_id, (s, parent) in enumerate(zip(data, sampleset.samples)): # Creating one sample per comparison subprobe_scores = [] - for ref in [ - r for r in biometric_references if r.key in sampleset.references - ]: - score = self.score(ref.data, s) - data = make_four_colums_score( - ref.subject, sampleset.subject, sampleset.path, score + + if self.allow_score_multiple_references: + # Multiple scoring + if self.stacked_biometric_references is None: + self.stacked_biometric_references = [ + ref.data for ref in biometric_references + ] + scores = self.score_multiple_biometric_references( + self.stacked_biometric_references, s ) - subprobe_scores.append(Sample(data, parent=ref)) + + # Wrapping the scores in samples + for ref, score in zip(biometric_references, scores): + subprobe_scores.append(_write_sample(ref, sampleset, score[0])) + else: + + for ref in [ + r for r in biometric_references if r.key in sampleset.references + ]: + score = self.score(ref.data, s) + subprobe_scores.append(_write_sample(ref, sampleset, score)) # Creating one sampleset per probe subprobe = SampleSet(subprobe_scores, parent=sampleset) @@ -139,6 +169,23 @@ class BioAlgorithm(metaclass=ABCMeta): """ pass + @abstractmethod + def score_multiple_biometric_references(self, biometric_references, data): + """ + It handles the score computation of one probe and multiple biometric references + This method is called is called if `allow_scoring_multiple_references` is set to true + + Parameters + ---------- + + biometric_references: list + List of biometric references to be scored + data: + Data used for the creation of ONE BIOMETRIC REFERENCE + + """ + pass + class Database(metaclass=ABCMeta): """Base class for Vanilla Biometric pipeline diff --git a/bob/bio/base/pipelines/vanilla_biometrics/legacy.py b/bob/bio/base/pipelines/vanilla_biometrics/legacy.py index 725d99bbcc20a4f69e910121de70cdfe3edcfc6e..0681fbee432d9e1eac3873fe63000eb8d52cf279 100644 --- a/bob/bio/base/pipelines/vanilla_biometrics/legacy.py +++ b/bob/bio/base/pipelines/vanilla_biometrics/legacy.py @@ -173,12 +173,12 @@ class _NonPickableWrapper: def __setstate__(self, d): # Handling unpicklable objects self._instance = None - return super().__setstate__(d) + #return super().__setstate__(d) def __getstate__(self): # Handling unpicklable objects self._instance = None - return super().__getstate__() + #return super().__getstate__() class _Preprocessor(_NonPickableWrapper, TransformerMixin, BaseEstimator):