diff --git a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py index 29ef62f3de67423448cecb164508ed9760a22741..15eb02dd37f3f6df3c472816d663f3fe7159e217 100644 --- a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py +++ b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py @@ -11,7 +11,8 @@ import h5py import cloudpickle from .zt_norm import ZTNormPipeline, ZTNormDaskWrapper from .legacy import BioAlgorithmLegacy - +from bob.bio.base.transformers import PreprocessorTransformer, ExtractorTransformer, AlgorithmTransformer +from bob.pipelines.wrappers import SampleWrapper class BioAlgorithmCheckpointWrapper(BioAlgorithm): """Wrapper used to checkpoint enrolled and Scoring samples. @@ -278,13 +279,30 @@ def checkpoint_vanilla_biometrics(pipeline, base_dir): sk_pipeline = pipeline.transformer for i, name, estimator in sk_pipeline._iter(): + # If they are legacy objects, we need to hook their load/save functions + save_func=None + load_func=None + + if not isinstance(estimator, SampleWrapper): + raise ValueError(f"{estimator} needs to be the type `SampleWrapper` to be checkpointed") + + if isinstance(estimator.estimator, PreprocessorTransformer): + save_func = estimator.estimator.instance.write_data + load_func = estimator.estimator.instance.read_data + elif any([isinstance(estimator.estimator, ExtractorTransformer), + isinstance(estimator.estimator, AlgorithmTransformer)]): + save_func = estimator.estimator.instance.write_feature + load_func = estimator.estimator.instance.read_feature + wraped_estimator = bob.pipelines.wrap( - ["checkpoint"], estimator, features_dir=os.path.join(base_dir, name) + ["checkpoint"], estimator, features_dir=os.path.join(base_dir, name), + load_func=load_func, + save_func=save_func ) sk_pipeline.steps[i] = (name, wraped_estimator) - if isinstance(pipeline.biometric_algorithm, BioAlgorithmLegacy): + if isinstance(pipeline.biometric_algorithm, BioAlgorithmLegacy): pipeline.biometric_algorithm.base_dir = base_dir else: pipeline.biometric_algorithm = BioAlgorithmCheckpointWrapper(