diff --git a/bob/bio/base/pipelines/vanilla_biometrics/__init__.py b/bob/bio/base/pipelines/vanilla_biometrics/__init__.py index 02c9a2a0a09f1521ac0964989c556f6b153057d8..0c04f2b34156259d3e297f8c611b6e626c873ca7 100644 --- a/bob/bio/base/pipelines/vanilla_biometrics/__init__.py +++ b/bob/bio/base/pipelines/vanilla_biometrics/__init__.py @@ -2,7 +2,7 @@ from .pipelines import VanillaBiometricsPipeline from .biometric_algorithms import Distance from .score_writers import FourColumnsScoreWriter, CSVScoreWriter -from .wrappers import BioAlgorithmCheckpointWrapper, BioAlgorithmDaskWrapper, dask_vanilla_biometrics +from .wrappers import BioAlgorithmCheckpointWrapper, BioAlgorithmDaskWrapper, dask_vanilla_biometrics, checkpoint_vanilla_biometrics from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper, ZTNormCheckpointWrapper diff --git a/bob/bio/base/pipelines/vanilla_biometrics/legacy.py b/bob/bio/base/pipelines/vanilla_biometrics/legacy.py index 2a218b7ab1f3ab964a7fefa1ddbcbc5a7126a1f7..339c06d13e6ad8b97ef2dd89e4d2e9d8426d3aa2 100644 --- a/bob/bio/base/pipelines/vanilla_biometrics/legacy.py +++ b/bob/bio/base/pipelines/vanilla_biometrics/legacy.py @@ -54,10 +54,33 @@ class DatabaseConnector(Database): The name of the protocol to generate samples from. To be plugged at :py:method:`bob.db.base.Database.objects`. + allow_scoring_with_all_biometric_references: bool + If True will allow the scoring function to be performed in one shot with multiple probes. + This optimization is useful when all probes needs to be compared with all biometric references AND + your scoring function allows this broadcast computation. + + annotation_type: str + Type of the annotations that the database provide. + Allowed types are: `eyes-center` and `bounding-box` + + fixed_positions: dict + In case database contains one single annotation for all samples. + This is useful for registered databases. """ - def __init__(self, database, **kwargs): + def __init__( + self, + database, + allow_scoring_with_all_biometric_references=True, + annotation_type="eyes-center", + fixed_positions=None, + ** kwargs, + ): self.database = database + self.allow_scoring_with_all_biometric_references = allow_scoring_with_all_biometric_references + self.annotation_type = annotation_type + self.fixed_positions=fixed_positions + def background_model_samples(self): """Returns :py:class:`Sample`'s to train a background model (group diff --git a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py index fed15da90ea07dfff467b338e193da49378eeec5..29ef62f3de67423448cecb164508ed9760a22741 100644 --- a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py +++ b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py @@ -5,11 +5,12 @@ import dask import functools from .score_writers import FourColumnsScoreWriter from .abstract_classes import BioAlgorithm -import bob.pipelines as mario +import bob.pipelines import numpy as np import h5py import cloudpickle from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper +from .legacy import BioAlgorithmLegacy class BioAlgorithmCheckpointWrapper(BioAlgorithm): @@ -38,7 +39,9 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm): """ - def __init__(self, biometric_algorithm, base_dir, group=None, force=False, **kwargs): + def __init__( + self, biometric_algorithm, base_dir, group=None, force=False, **kwargs + ): super().__init__(**kwargs) self.base_dir = base_dir @@ -47,14 +50,18 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm): self.biometric_algorithm = biometric_algorithm self.force = force self._biometric_reference_extension = ".hdf5" - self._score_extension = ".pkl" + self._score_extension = ".pkl" def set_score_references_path(self, group): if group is None: - self.biometric_reference_dir = os.path.join(self.base_dir, "biometric_references") + self.biometric_reference_dir = os.path.join( + self.base_dir, "biometric_references" + ) self.score_dir = os.path.join(self.base_dir, "scores") else: - self.biometric_reference_dir = os.path.join(self.base_dir, group, "biometric_references") + self.biometric_reference_dir = os.path.join( + self.base_dir, group, "biometric_references" + ) self.score_dir = os.path.join(self.base_dir, group, "scores") def enroll(self, enroll_features): @@ -113,7 +120,7 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm): def _load(path): return cloudpickle.loads(open(path, "rb").read()) - #with h5py.File(path) as hdf5: + # with h5py.File(path) as hdf5: # return hdf5_to_sample(hdf5) def _make_name(sampleset, biometric_references): @@ -125,7 +132,8 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm): return os.path.join(subject, name + suffix) path = os.path.join( - self.score_dir, _make_name(sampleset, biometric_references) + self._score_extension + self.score_dir, + _make_name(sampleset, biometric_references) + self._score_extension, ) if self.force or not os.path.exists(path): @@ -221,7 +229,7 @@ def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None): """ if isinstance(pipeline, ZTNormPipeline): - # Dasking the first part of the pipelines + # Dasking the first part of the pipelines pipeline.vanilla_biometrics_pipeline = dask_vanilla_biometrics( pipeline.vanilla_biometrics_pipeline, npartitions ) @@ -231,11 +239,11 @@ def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None): else: if partition_size is None: - pipeline.transformer = mario.wrap( + pipeline.transformer = bob.pipelines.wrap( ["dask"], pipeline.transformer, npartitions=npartitions ) else: - pipeline.transformer = mario.wrap( + pipeline.transformer = bob.pipelines.wrap( ["dask"], pipeline.transformer, partition_size=partition_size ) pipeline.biometric_algorithm = BioAlgorithmDaskWrapper( @@ -249,3 +257,38 @@ def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None): pipeline.write_scores = _write_scores return pipeline + + +def checkpoint_vanilla_biometrics(pipeline, base_dir): + """ + Given a :any:`VanillaBiometrics`, wraps :any:`VanillaBiometrics.transformer` and + :any:`VanillaBiometrics.biometric_algorithm` to be checkpointed + + Parameters + ---------- + + pipeline: :any:`VanillaBiometrics` + Vanilla Biometrics based pipeline to be dasked + + base_dir: str + Path to store biometric references and scores + + """ + + sk_pipeline = pipeline.transformer + for i, name, estimator in sk_pipeline._iter(): + + wraped_estimator = bob.pipelines.wrap( + ["checkpoint"], estimator, features_dir=os.path.join(base_dir, name) + ) + + sk_pipeline.steps[i] = (name, wraped_estimator) + + if isinstance(pipeline.biometric_algorithm, BioAlgorithmLegacy): + pipeline.biometric_algorithm.base_dir = base_dir + else: + pipeline.biometric_algorithm = BioAlgorithmCheckpointWrapper( + pipeline.biometric_algorithm, base_dir=base_dir + ) + + return pipeline