Skip to content
Snippets Groups Projects
Commit fe38ce3f authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Created a wrapper that wraps vanilla biometrics pipelines to checkpointin.

parent da5069ed
No related branches found
No related tags found
2 merge requests!192Redoing baselines,!180[dask] Preparing bob.bio.base for dask pipelines
......@@ -11,7 +11,8 @@ import h5py
import cloudpickle
from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper
from .legacy import BioAlgorithmLegacy
from bob.bio.base.transformers import PreprocessorTransformer, ExtractorTransformer, AlgorithmTransformer
from bob.pipelines.wrappers import SampleWrapper
class BioAlgorithmCheckpointWrapper(BioAlgorithm):
"""Wrapper used to checkpoint enrolled and Scoring samples.
......@@ -278,13 +279,30 @@ def checkpoint_vanilla_biometrics(pipeline, base_dir):
sk_pipeline = pipeline.transformer
for i, name, estimator in sk_pipeline._iter():
# If they are legacy objects, we need to hook their load/save functions
save_func=None
load_func=None
if not isinstance(estimator, SampleWrapper):
raise ValueError(f"{estimator} needs to be the type `SampleWrapper` to be checkpointed")
if isinstance(estimator.estimator, PreprocessorTransformer):
save_func = estimator.estimator.instance.write_data
load_func = estimator.estimator.instance.read_data
elif any([isinstance(estimator.estimator, ExtractorTransformer),
isinstance(estimator.estimator, AlgorithmTransformer)]):
save_func = estimator.estimator.instance.write_feature
load_func = estimator.estimator.instance.read_feature
wraped_estimator = bob.pipelines.wrap(
["checkpoint"], estimator, features_dir=os.path.join(base_dir, name)
["checkpoint"], estimator, features_dir=os.path.join(base_dir, name),
load_func=load_func,
save_func=save_func
)
sk_pipeline.steps[i] = (name, wraped_estimator)
if isinstance(pipeline.biometric_algorithm, BioAlgorithmLegacy):
if isinstance(pipeline.biometric_algorithm, BioAlgorithmLegacy):
pipeline.biometric_algorithm.base_dir = base_dir
else:
pipeline.biometric_algorithm = BioAlgorithmCheckpointWrapper(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment