Commit 55795380 authored by Tiago de Freitas Pereira

[legacy] Changed the way vanilla-biometric pipelines are checkpoint wrapped by using the bob_features_save_fn and bob_features_load_fn tags
parent 7a96613a
Pipeline #51789 passed in 11 minutes and 37 seconds
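
For context on the mechanism this commit leans on: scikit-learn merges whatever an estimator returns from _more_tags() into the dict produced by BaseEstimator._get_tags() (on the scikit-learn releases current when this commit landed; newer releases moved to a different tags API). Custom keys such as bob_features_save_fn simply ride along with the standard tags. A minimal sketch with a hypothetical DummyLegacyTransformer, not part of this commit:

    from sklearn.base import BaseEstimator, TransformerMixin

    class DummyLegacyTransformer(TransformerMixin, BaseEstimator):
        """Hypothetical stand-in for a wrapped legacy bob object."""

        def transform(self, X):
            return X

        def _more_tags(self):
            # Custom keys are merged with the standard sklearn tags
            return {
                "bob_features_save_fn": lambda data, path: None,
                "bob_features_load_fn": lambda path: None,
            }

    tags = DummyLegacyTransformer()._get_tags()
    assert "bob_features_save_fn" in tags  # the custom tag is visible to consumers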
@@ -320,36 +320,10 @@ def checkpoint_vanilla_biometrics(
    for i, name, estimator in sk_pipeline._iter():
        # If they are legacy objects, we need to hook their load/save functions
        save_func = None
        load_func = None
        if not isinstance(estimator, SampleWrapper):
            raise ValueError(
                f"{estimator} needs to be the type `SampleWrapper` to be checkpointed"
            )
        if isinstance(estimator.estimator, PreprocessorTransformer):
            save_func = estimator.estimator.instance.write_data
            load_func = estimator.estimator.instance.read_data
        elif any(
            [
                isinstance(estimator.estimator, ExtractorTransformer),
                isinstance(estimator.estimator, AlgorithmTransformer),
            ]
        ):
            save_func = estimator.estimator.instance.write_feature
            load_func = estimator.estimator.instance.read_feature
            estimator.estimator.projector_file = os.path.join(
                bio_ref_scores_dir, "Projector.hdf5"
            )
        wraped_estimator = bob.pipelines.wrap(
            ["checkpoint"],
            estimator,
            features_dir=os.path.join(base_dir, name),
            load_func=load_func,
            save_func=save_func,
            hash_fn=hash_fn,
        )
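
The added side of this hunk is collapsed in this view, but given the commit message it presumably swaps the isinstance dispatch above for a lookup of the new tags. A hedged sketch of one plausible shape, assuming SampleWrapper is importable from bob.pipelines as in the surrounding module and that bob.pipelines.wrap keeps the signature used above; the tags could equally be consumed inside bob.pipelines' checkpoint wrapper itself:

    for i, name, estimator in sk_pipeline._iter():
        if not isinstance(estimator, SampleWrapper):
            raise ValueError(
                f"{estimator} needs to be the type `SampleWrapper` to be checkpointed"
            )
        # Hypothetical: ask the wrapped transformer for its I/O hooks via its
        # sklearn tags instead of dispatching on its concrete class.
        tags = estimator.estimator._get_tags()
        wraped_estimator = bob.pipelines.wrap(
            ["checkpoint"],
            estimator,
            features_dir=os.path.join(base_dir, name),
            load_func=tags.get("bob_features_load_fn"),
            save_func=tags.get("bob_features_save_fn"),
            hash_fn=hash_fn,
        )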
@@ -78,4 +78,6 @@ class AlgorithmTransformer(TransformerMixin, BaseEstimator):
         return {
             "stateless": not self.instance.requires_projector_training,
             "requires_fit": self.instance.requires_projector_training,
+            "bob_features_save_fn": self.instance.write_feature,
+            "bob_features_load_fn": self.instance.read_feature,
         }
@@ -59,4 +59,7 @@ class ExtractorTransformer(TransformerMixin, BaseEstimator):
         return {
             "stateless": not self.instance.requires_training,
             "requires_fit": self.instance.requires_training,
+            "bob_features_save_fn": self.instance.write_feature,
+            "bob_features_load_fn": self.instance.read_feature,
         }
@@ -33,7 +33,12 @@ class PreprocessorTransformer(TransformerMixin, BaseEstimator):
         return [self.instance(data, annot) for data, annot in zip(X, annotations)]

     def _more_tags(self):
-        return {"stateless": True, "requires_fit": False}
+        return {
+            "stateless": True,
+            "requires_fit": False,
+            "bob_features_save_fn": self.instance.write_data,
+            "bob_features_load_fn": self.instance.read_data,
+        }

     def fit(self, X, y=None):
         return self
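
On the consuming side, any code that checkpoints a pipeline can now recover the hooks uniformly, whatever the concrete legacy type. A hypothetical helper (not part of this commit) illustrating the lookup, with a fallback for transformers that do not advertise the tags:

    def legacy_io_hooks(estimator):
        """Hypothetical helper: fetch the save/load hooks off an estimator's
        sklearn tags, returning (None, None) for non-legacy transformers."""
        tags = estimator._get_tags() if hasattr(estimator, "_get_tags") else {}
        return tags.get("bob_features_save_fn"), tags.get("bob_features_load_fn")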