From fe38ce3f566a88f1138783a055a24fecbb374423 Mon Sep 17 00:00:00 2001 From: Tiago Freitas Pereira <tiagofrepereira@gmail.com> Date: Tue, 16 Jun 2020 09:35:36 +0200 Subject: [PATCH] Created a wrapper that wraps vanilla biometrics pipelines to checkpointin. --- .../pipelines/vanilla_biometrics/wrappers.py | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py index 29ef62f3..15eb02dd 100644 --- a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py +++ b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py @@ -11,7 +11,8 @@ import h5py import cloudpickle from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper from .legacy import BioAlgorithmLegacy - +from bob.bio.base.transformers import PreprocessorTransformer, ExtractorTransformer, AlgorithmTransformer +from bob.pipelines.wrappers import SampleWrapper class BioAlgorithmCheckpointWrapper(BioAlgorithm): """Wrapper used to checkpoint enrolled and Scoring samples. @@ -278,13 +279,30 @@ def checkpoint_vanilla_biometrics(pipeline, base_dir): sk_pipeline = pipeline.transformer for i, name, estimator in sk_pipeline._iter(): + # If they are legacy objects, we need to hook their load/save functions + save_func=None + load_func=None + + if not isinstance(estimator, SampleWrapper): + raise ValueError(f"{estimator} needs to be the type `SampleWrapper` to be checkpointed") + + if isinstance(estimator.estimator, PreprocessorTransformer): + save_func = estimator.estimator.instance.write_data + load_func = estimator.estimator.instance.read_data + elif any([isinstance(estimator.estimator, ExtractorTransformer), + isinstance(estimator.estimator, AlgorithmTransformer)]): + save_func = estimator.estimator.instance.write_feature + load_func = estimator.estimator.instance.read_feature + wraped_estimator = bob.pipelines.wrap( - ["checkpoint"], estimator, features_dir=os.path.join(base_dir, name) + ["checkpoint"], estimator, features_dir=os.path.join(base_dir, name), + load_func=load_func, + save_func=save_func ) sk_pipeline.steps[i] = (name, wraped_estimator) - if isinstance(pipeline.biometric_algorithm, BioAlgorithmLegacy): + if isinstance(pipeline.biometric_algorithm, BioAlgorithmLegacy): pipeline.biometric_algorithm.base_dir = base_dir else: pipeline.biometric_algorithm = BioAlgorithmCheckpointWrapper( -- GitLab