Skip to content
Snippets Groups Projects
Commit 13bb6c12 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Created checkpoint wrapper

parent e3e4e972
No related branches found
No related tags found
2 merge requests!192Redoing baselines,!180[dask] Preparing bob.bio.base for dask pipelines
Pipeline #40478 failed
...@@ -2,7 +2,7 @@ from .pipelines import VanillaBiometricsPipeline ...@@ -2,7 +2,7 @@ from .pipelines import VanillaBiometricsPipeline
from .biometric_algorithms import Distance from .biometric_algorithms import Distance
from .score_writers import FourColumnsScoreWriter, CSVScoreWriter from .score_writers import FourColumnsScoreWriter, CSVScoreWriter
from .wrappers import BioAlgorithmCheckpointWrapper, BioAlgorithmDaskWrapper, dask_vanilla_biometrics from .wrappers import BioAlgorithmCheckpointWrapper, BioAlgorithmDaskWrapper, dask_vanilla_biometrics, checkpoint_vanilla_biometrics
from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper, ZTNormCheckpointWrapper from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper, ZTNormCheckpointWrapper
......
...@@ -54,10 +54,33 @@ class DatabaseConnector(Database): ...@@ -54,10 +54,33 @@ class DatabaseConnector(Database):
The name of the protocol to generate samples from. The name of the protocol to generate samples from.
To be plugged at :py:method:`bob.db.base.Database.objects`. To be plugged at :py:method:`bob.db.base.Database.objects`.
allow_scoring_with_all_biometric_references: bool
If True will allow the scoring function to be performed in one shot with multiple probes.
This optimization is useful when all probes needs to be compared with all biometric references AND
your scoring function allows this broadcast computation.
annotation_type: str
Type of the annotations that the database provide.
Allowed types are: `eyes-center` and `bounding-box`
fixed_positions: dict
In case database contains one single annotation for all samples.
This is useful for registered databases.
""" """
def __init__(self, database, **kwargs): def __init__(
self,
database,
allow_scoring_with_all_biometric_references=True,
annotation_type="eyes-center",
fixed_positions=None,
** kwargs,
):
self.database = database self.database = database
self.allow_scoring_with_all_biometric_references = allow_scoring_with_all_biometric_references
self.annotation_type = annotation_type
self.fixed_positions=fixed_positions
def background_model_samples(self): def background_model_samples(self):
"""Returns :py:class:`Sample`'s to train a background model (group """Returns :py:class:`Sample`'s to train a background model (group
......
...@@ -5,11 +5,12 @@ import dask ...@@ -5,11 +5,12 @@ import dask
import functools import functools
from .score_writers import FourColumnsScoreWriter from .score_writers import FourColumnsScoreWriter
from .abstract_classes import BioAlgorithm from .abstract_classes import BioAlgorithm
import bob.pipelines as mario import bob.pipelines
import numpy as np import numpy as np
import h5py import h5py
import cloudpickle import cloudpickle
from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper
from .legacy import BioAlgorithmLegacy
class BioAlgorithmCheckpointWrapper(BioAlgorithm): class BioAlgorithmCheckpointWrapper(BioAlgorithm):
...@@ -38,7 +39,9 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm): ...@@ -38,7 +39,9 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
""" """
def __init__(self, biometric_algorithm, base_dir, group=None, force=False, **kwargs): def __init__(
self, biometric_algorithm, base_dir, group=None, force=False, **kwargs
):
super().__init__(**kwargs) super().__init__(**kwargs)
self.base_dir = base_dir self.base_dir = base_dir
...@@ -47,14 +50,18 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm): ...@@ -47,14 +50,18 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
self.biometric_algorithm = biometric_algorithm self.biometric_algorithm = biometric_algorithm
self.force = force self.force = force
self._biometric_reference_extension = ".hdf5" self._biometric_reference_extension = ".hdf5"
self._score_extension = ".pkl" self._score_extension = ".pkl"
def set_score_references_path(self, group): def set_score_references_path(self, group):
if group is None: if group is None:
self.biometric_reference_dir = os.path.join(self.base_dir, "biometric_references") self.biometric_reference_dir = os.path.join(
self.base_dir, "biometric_references"
)
self.score_dir = os.path.join(self.base_dir, "scores") self.score_dir = os.path.join(self.base_dir, "scores")
else: else:
self.biometric_reference_dir = os.path.join(self.base_dir, group, "biometric_references") self.biometric_reference_dir = os.path.join(
self.base_dir, group, "biometric_references"
)
self.score_dir = os.path.join(self.base_dir, group, "scores") self.score_dir = os.path.join(self.base_dir, group, "scores")
def enroll(self, enroll_features): def enroll(self, enroll_features):
...@@ -113,7 +120,7 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm): ...@@ -113,7 +120,7 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
def _load(path): def _load(path):
return cloudpickle.loads(open(path, "rb").read()) return cloudpickle.loads(open(path, "rb").read())
#with h5py.File(path) as hdf5: # with h5py.File(path) as hdf5:
# return hdf5_to_sample(hdf5) # return hdf5_to_sample(hdf5)
def _make_name(sampleset, biometric_references): def _make_name(sampleset, biometric_references):
...@@ -125,7 +132,8 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm): ...@@ -125,7 +132,8 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
return os.path.join(subject, name + suffix) return os.path.join(subject, name + suffix)
path = os.path.join( path = os.path.join(
self.score_dir, _make_name(sampleset, biometric_references) + self._score_extension self.score_dir,
_make_name(sampleset, biometric_references) + self._score_extension,
) )
if self.force or not os.path.exists(path): if self.force or not os.path.exists(path):
...@@ -221,7 +229,7 @@ def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None): ...@@ -221,7 +229,7 @@ def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None):
""" """
if isinstance(pipeline, ZTNormPipeline): if isinstance(pipeline, ZTNormPipeline):
# Dasking the first part of the pipelines # Dasking the first part of the pipelines
pipeline.vanilla_biometrics_pipeline = dask_vanilla_biometrics( pipeline.vanilla_biometrics_pipeline = dask_vanilla_biometrics(
pipeline.vanilla_biometrics_pipeline, npartitions pipeline.vanilla_biometrics_pipeline, npartitions
) )
...@@ -231,11 +239,11 @@ def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None): ...@@ -231,11 +239,11 @@ def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None):
else: else:
if partition_size is None: if partition_size is None:
pipeline.transformer = mario.wrap( pipeline.transformer = bob.pipelines.wrap(
["dask"], pipeline.transformer, npartitions=npartitions ["dask"], pipeline.transformer, npartitions=npartitions
) )
else: else:
pipeline.transformer = mario.wrap( pipeline.transformer = bob.pipelines.wrap(
["dask"], pipeline.transformer, partition_size=partition_size ["dask"], pipeline.transformer, partition_size=partition_size
) )
pipeline.biometric_algorithm = BioAlgorithmDaskWrapper( pipeline.biometric_algorithm = BioAlgorithmDaskWrapper(
...@@ -249,3 +257,38 @@ def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None): ...@@ -249,3 +257,38 @@ def dask_vanilla_biometrics(pipeline, npartitions=None, partition_size=None):
pipeline.write_scores = _write_scores pipeline.write_scores = _write_scores
return pipeline return pipeline
def checkpoint_vanilla_biometrics(pipeline, base_dir):
    """
    Given a :any:`VanillaBiometrics`, wraps :any:`VanillaBiometrics.transformer` and
    :any:`VanillaBiometrics.biometric_algorithm` to be checkpointed

    Parameters
    ----------

    pipeline: :any:`VanillaBiometrics`
       Vanilla Biometrics based pipeline to be checkpointed

    base_dir: str
       Path to store transformed features, biometric references and scores

    Returns
    -------
    :any:`VanillaBiometrics`
       The same pipeline instance, mutated in place with checkpointing wrappers.
    """

    sk_pipeline = pipeline.transformer

    # Wrap every estimator of the sklearn pipeline with a checkpoint wrapper
    # that caches transformed features under ``base_dir/<step name>``.
    for i, name, estimator in sk_pipeline._iter():
        wrapped_estimator = bob.pipelines.wrap(
            ["checkpoint"], estimator, features_dir=os.path.join(base_dir, name)
        )
        sk_pipeline.steps[i] = (name, wrapped_estimator)

    if isinstance(pipeline.biometric_algorithm, BioAlgorithmLegacy):
        # Legacy algorithms already handle their own checkpointing; just
        # redirect them to the requested directory instead of re-wrapping.
        pipeline.biometric_algorithm.base_dir = base_dir
    else:
        pipeline.biometric_algorithm = BioAlgorithmCheckpointWrapper(
            pipeline.biometric_algorithm, base_dir=base_dir
        )

    return pipeline
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment