Commit ce29f3bb authored by Tiago de Freitas Pereira

Implemented extra wrappers for bob legacy and ported some default options from cd ..

parent ca62d02e
2 merge requests: !185 Wrappers and aggregators, !180 [dask] Preparing bob.bio.base for dask pipelines
Pipeline #39657 passed
@@ -16,10 +16,12 @@ import tempfile
 import os
 import bob.io.base
 from bob.bio.base.wrappers import (
-    wrap_preprocessor,
-    wrap_extractor,
-    wrap_algorithm,
-    wrap_bob_legacy,
+    wrap_checkpoint_preprocessor,
+    wrap_checkpoint_extractor,
+    wrap_checkpoint_algorithm,
+    wrap_sample_preprocessor,
+    wrap_sample_extractor,
+    wrap_sample_algorithm,
 )
 from sklearn.pipeline import make_pipeline
@@ -30,7 +32,7 @@ class FakePreprocesor(Preprocessor):
 class FakeExtractor(Extractor):
-    def __call__(self, data, metadata=None):
+    def __call__(self, data):
         return data.flatten()
@@ -56,7 +58,7 @@ class FakeAlgorithm(Algorithm):
         self.split_training_features_by_client = True
         self.model = None

-    def project(self, data, metadata=None):
+    def project(self, data):
         return data + self.model

     def train_projector(self, training_features, projector_file):
@@ -259,25 +261,30 @@ def test_algorithm():
 def test_wrap_bob_pipeline():
-    def run_pipeline(with_dask):
+    def run_pipeline(with_dask, with_checkpoint):
         with tempfile.TemporaryDirectory() as dir_name:
-            pipeline = make_pipeline(
-                wrap_bob_legacy(
-                    FakePreprocesor(),
-                    dir_name,
-                    transform_extra_arguments=(("annotations", "annotations"),),
-                ),
-                wrap_bob_legacy(FakeExtractor(), dir_name,),
-                wrap_bob_legacy(FakeAlgorithm(), dir_name),
-            )
+            if with_checkpoint:
+                pipeline = make_pipeline(
+                    wrap_checkpoint_preprocessor(FakePreprocesor(), dir_name,),
+                    wrap_checkpoint_extractor(FakeExtractor(), dir_name,),
+                    wrap_checkpoint_algorithm(FakeAlgorithm(), dir_name),
+                )
+            else:
+                pipeline = make_pipeline(
+                    wrap_sample_preprocessor(FakePreprocesor()),
+                    wrap_sample_extractor(FakeExtractor(), dir_name,),
+                    wrap_sample_algorithm(FakeAlgorithm(), dir_name),
+                )
             oracle = [7.0, 7.0, 7.0, 7.0]
             training_samples = generate_samples(n_subjects=2, n_samples_per_subject=2)
             test_samples = generate_samples(n_subjects=1, n_samples_per_subject=1)
             if with_dask:
                 pipeline = mario.wrap(["dask"], pipeline)
                 transformed_samples = (
-                    pipeline.fit(training_samples).transform(test_samples).compute(scheduler="single-threaded")
+                    pipeline.fit(training_samples)
+                    .transform(test_samples)
+                    .compute(scheduler="single-threaded")
                 )
             else:
                 transformed_samples = pipeline.fit(training_samples).transform(
@@ -285,5 +292,7 @@ def test_wrap_bob_pipeline():
             )
             assert assert_sample(transformed_samples, oracle)

-    run_pipeline(False)
-    run_pipeline(True)
+    run_pipeline(False, False)
+    run_pipeline(False, True)
+    run_pipeline(True, False)
+    run_pipeline(True, True)
@@ -8,7 +8,11 @@ import numpy as np
 import tempfile
 from sklearn.pipeline import make_pipeline
 from bob.bio.base.wrappers import wrap_bob_legacy
-from bob.bio.base.test.test_transformers import FakePreprocesor, FakeExtractor, FakeAlgorithm
+from bob.bio.base.test.test_transformers import (
+    FakePreprocesor,
+    FakeExtractor,
+    FakeAlgorithm,
+)
 from bob.bio.base.pipelines.vanilla_biometrics import (
     Distance,
     VanillaBiometricsPipeline,
@@ -16,7 +20,7 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
     dask_vanilla_biometrics,
     FourColumnsScoreWriter,
     CSVScoreWriter,
-    BioAlgorithmLegacy
+    BioAlgorithmLegacy,
 )
 import bob.pipelines as mario
@@ -24,6 +28,7 @@ import uuid
 import shutil
 import itertools

+
 class DummyDatabase:
     def __init__(self, delayed=False, n_references=10, n_probes=10, dim=10, one_d=True):
         self.delayed = delayed
@@ -36,13 +41,23 @@ class DummyDatabase:
     def _create_random_1dsamples(self, n_samples, offset, dim):
         return [
-            Sample(np.random.rand(dim), key=str(uuid.uuid4()), annotations=1, subject=str(i))
+            Sample(
+                np.random.rand(dim),
+                key=str(uuid.uuid4()),
+                annotations=1,
+                subject=str(i),
+            )
             for i in range(offset, offset + n_samples)
         ]

     def _create_random_2dsamples(self, n_samples, offset, dim):
         return [
-            Sample(np.random.rand(dim, dim), key=str(uuid.uuid4()), annotations=1, subject=str(i))
+            Sample(
+                np.random.rand(dim, dim),
+                key=str(uuid.uuid4()),
+                annotations=1,
+                subject=str(i),
+            )
             for i in range(offset, offset + n_samples)
         ]
@@ -74,7 +89,7 @@ class DummyDatabase:
         return sample_set

     def background_model_samples(self):
         samples = [sset.samples for sset in self._create_random_sample_set()]
         return list(itertools.chain(*samples))

     def references(self):
@@ -101,11 +116,12 @@ def _make_transformer(dir_name):
             dir_name,
             transform_extra_arguments=(("annotations", "annotations"),),
         ),
-        wrap_bob_legacy(FakeExtractor(), dir_name,)
+        wrap_bob_legacy(FakeExtractor(), dir_name,),
     )
     return pipeline

+
 def _make_transformer_with_algorithm(dir_name):
     pipeline = make_pipeline(
         wrap_bob_legacy(
@@ -114,7 +130,7 @@ def _make_transformer_with_algorithm(dir_name):
             transform_extra_arguments=(("annotations", "annotations"),),
         ),
         wrap_bob_legacy(FakeExtractor(), dir_name),
-        wrap_bob_legacy(FakeAlgorithm(), dir_name)
+        wrap_bob_legacy(FakeAlgorithm(), dir_name),
     )
     return pipeline
@@ -197,7 +213,9 @@ def test_checkpoint_bioalg_as_transformer():
         if isinstance(score_writer, CSVScoreWriter):
             base_path = os.path.join(dir_name, "concatenated_scores")
             score_writer.concatenate_write_scores(scores, base_path)
-            assert len(open(os.path.join(base_path, "chunk_0.csv")).readlines()) == 101
+            assert (
+                len(open(os.path.join(base_path, "chunk_0.csv")).readlines()) == 101
+            )
         else:
             filename = os.path.join(dir_name, "concatenated_scores.txt")
             score_writer.concatenate_write_scores(scores, filename)
@@ -205,24 +223,24 @@ def test_checkpoint_bioalg_as_transformer():
     run_pipeline(False)
     run_pipeline(False)  # Checking if the checkpointing works
     shutil.rmtree(dir_name)  # Deleting the cache so it runs again from scratch
     os.makedirs(dir_name, exist_ok=True)

     # Dask
     run_pipeline(True)
     run_pipeline(True)  # Checking if the checkpointing works
     shutil.rmtree(dir_name)  # Deleting the cache so it runs again from scratch
     os.makedirs(dir_name, exist_ok=True)

     # CSVWriter
     run_pipeline(False, CSVScoreWriter())
     run_pipeline(False, CSVScoreWriter())  # Checking if the checkpointing works
     shutil.rmtree(dir_name)  # Deleting the cache so it runs again from scratch
     os.makedirs(dir_name, exist_ok=True)

     # CSVWriter + Dask
     run_pipeline(True, CSVScoreWriter())
     run_pipeline(True, CSVScoreWriter())  # Checking if the checkpointing works


 def test_checkpoint_bioalg_as_bioalg():
@@ -231,12 +249,15 @@ def test_checkpoint_bioalg_as_bioalg():
     def run_pipeline(with_dask, score_writer=FourColumnsScoreWriter()):
         database = DummyDatabase()
         transformer = _make_transformer_with_algorithm(dir_name)
         projector_file = transformer[2].estimator.estimator.projector_file

         biometric_algorithm = BioAlgorithmLegacy(
-            FakeAlgorithm(), base_dir=dir_name, score_writer=score_writer, projector_file=projector_file
+            FakeAlgorithm(),
+            base_dir=dir_name,
+            score_writer=score_writer,
+            projector_file=projector_file,
         )

         vanilla_biometrics_pipeline = VanillaBiometricsPipeline(
@@ -265,11 +286,11 @@ def test_checkpoint_bioalg_as_bioalg():
     run_pipeline(False)
     run_pipeline(False)  # Checking if the checkpointing works
     shutil.rmtree(dir_name)  # Deleting the cache so it runs again from scratch
     os.makedirs(dir_name, exist_ok=True)

     # Dask
     run_pipeline(True)
     run_pipeline(True)  # Checking if the checkpointing works
     shutil.rmtree(dir_name)  # Deleting the cache so it runs again from scratch
     os.makedirs(dir_name, exist_ok=True)
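A hedged sketch of the wiring this test exercises, built from names visible in this diff; the VanillaBiometricsPipeline(transformer, biometric_algorithm) argument order is an assumption:

    # BioAlgorithmLegacy reuses the projector checkpointed by the transformer
    # (transformer[2] is the wrapped FakeAlgorithm step of the pipeline).
    transformer = _make_transformer_with_algorithm(dir_name)
    projector_file = transformer[2].estimator.estimator.projector_file

    biometric_algorithm = BioAlgorithmLegacy(
        FakeAlgorithm(),
        base_dir=dir_name,
        score_writer=FourColumnsScoreWriter(),
        projector_file=projector_file,
    )
    pipeline = VanillaBiometricsPipeline(transformer, biometric_algorithm)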
@@ -11,6 +11,7 @@ from bob.bio.base.extractor import Extractor
 from bob.bio.base.algorithm import Algorithm
 import bob.pipelines as mario
 import os
+from bob.bio.base.utils import is_argument_available


 def wrap_bob_legacy(
@@ -18,7 +19,7 @@ def wrap_bob_legacy(
     dir_name,
     fit_extra_arguments=(("y", "subject"),),
     transform_extra_arguments=None,
-    dask_it=False
+    dask_it=False,
 ):
     """
     Wraps either :any:`bob.bio.base.preprocessor.Preprocessor`, :any:`bob.bio.base.extractor.Extractor`
@@ -47,33 +48,20 @@ def wrap_bob_legacy(
     """
     if isinstance(bob_object, Preprocessor):
-        preprocessor_transformer = PreprocessorTransformer(bob_object)
-        transformer = wrap_preprocessor(
-            preprocessor_transformer,
-            features_dir=os.path.join(dir_name, "preprocessor"),
-            transform_extra_arguments=transform_extra_arguments,
+        transformer = wrap_checkpoint_preprocessor(
+            bob_object, features_dir=os.path.join(dir_name, "preprocessor"),
         )
     elif isinstance(bob_object, Extractor):
-        extractor_transformer = ExtractorTransformer(bob_object)
-        path = os.path.join(dir_name, "extractor")
-        transformer = wrap_extractor(
-            extractor_transformer,
-            features_dir=path,
-            model_path=os.path.join(path, "extractor.pkl"),
-            transform_extra_arguments=transform_extra_arguments,
-            fit_extra_arguments=fit_extra_arguments,
+        transformer = wrap_checkpoint_extractor(
+            bob_object,
+            features_dir=os.path.join(dir_name, "extractor"),
+            model_path=dir_name,
         )
     elif isinstance(bob_object, Algorithm):
-        path = os.path.join(dir_name, "algorithm")
-        algorithm_transformer = AlgorithmTransformer(
-            bob_object, projector_file=os.path.join(path, "Projector.hdf5")
-        )
-        transformer = wrap_algorithm(
-            algorithm_transformer,
-            features_dir=path,
-            model_path=os.path.join(path, "algorithm.pkl"),
-            transform_extra_arguments=transform_extra_arguments,
-            fit_extra_arguments=fit_extra_arguments,
+        transformer = wrap_checkpoint_algorithm(
+            bob_object,
+            features_dir=os.path.join(dir_name, "algorithm"),
+            model_path=dir_name,
         )
     else:
         raise ValueError(
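The dispatch above keeps the single legacy entry point intact while delegating to the new checkpoint wrappers. A minimal usage sketch, mirroring the _make_transformer helper from the tests in this commit:

    # wrap_bob_legacy picks the right checkpoint wrapper from the object's type.
    from sklearn.pipeline import make_pipeline
    from bob.bio.base.wrappers import wrap_bob_legacy

    pipeline = make_pipeline(
        wrap_bob_legacy(
            FakePreprocesor(),  # any bob.bio.base Preprocessor
            dir_name,
            transform_extra_arguments=(("annotations", "annotations"),),
        ),
        wrap_bob_legacy(FakeExtractor(), dir_name),  # any Extractor
        wrap_bob_legacy(FakeAlgorithm(), dir_name),  # any Algorithm
    )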
@@ -86,132 +74,351 @@ def wrap_bob_legacy(
     return transformer

-def wrap_preprocessor(
-    preprocessor_transformer, features_dir=None, transform_extra_arguments=None,
-):
-    """
-    Wraps :any:`bob.bio.base.transformers.PreprocessorTransformer` with
-    :any:`bob.pipelines.wrappers.CheckpointWrapper` and :any:`bob.pipelines.wrappers.SampleWrapper`
-
-    Parameters
-    ----------
-
-    preprocessor_transformer: :any:`bob.bio.base.transformers.PreprocessorTransformer`
-        Instance of :any:`bob.bio.base.transformers.PreprocessorTransformer` to be wrapped
-
-    features_dir: str
-        Features directory to be checkpointed
-
-    transform_extra_arguments: [tuple]
-        Same behavior as in Check :any:`bob.pipelines.wrappers.transform_extra_arguments`
-    """
-    if not isinstance(preprocessor_transformer, PreprocessorTransformer):
-        raise ValueError(
-            f"Expected an instance of PreprocessorTransformer, not {preprocessor_transformer}"
-        )
-
-    return mario.wrap(
-        ["sample", "checkpoint"],
-        preprocessor_transformer,
-        load_func=preprocessor_transformer.callable.read_data,
-        save_func=preprocessor_transformer.callable.write_data,
-        features_dir=features_dir,
-        transform_extra_arguments=transform_extra_arguments,
-    )
+def wrap_sample_preprocessor(
+    preprocessor,
+    transform_extra_arguments=(("annotations", "annotations"),),
+    **kwargs
+):
+    """
+    Wraps :any:`bob.bio.base.preprocessor.Preprocessor` with
+    :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    .. warning::
+       This wrapper doesn't checkpoint data
+
+    Parameters
+    ----------
+
+    preprocessor: :any:`bob.bio.base.preprocessor.Preprocessor`
+        Instance of :any:`bob.bio.base.preprocessor.Preprocessor` to be wrapped
+
+    transform_extra_arguments: [tuple]
+        Same behavior as in :any:`bob.pipelines.wrappers.transform_extra_arguments`
+    """
+    transformer = PreprocessorTransformer(preprocessor)
+    return mario.wrap(
+        ["sample"],
+        transformer,
+        transform_extra_arguments=transform_extra_arguments,
+    )
+
+
+def wrap_checkpoint_preprocessor(
+    preprocessor,
+    features_dir=None,
+    transform_extra_arguments=(("annotations", "annotations"),),
+    load_func=None,
+    save_func=None,
+    extension=".hdf5",
+):
+    """
+    Wraps :any:`bob.bio.base.preprocessor.Preprocessor` with
+    :any:`bob.pipelines.wrappers.CheckpointWrapper` and :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    Parameters
+    ----------
+
+    preprocessor: :any:`bob.bio.base.preprocessor.Preprocessor`
+        Instance of :any:`bob.bio.base.preprocessor.Preprocessor` to be wrapped
+
+    features_dir: str
+        Features directory to be checkpointed (see :any:`bob.pipelines.CheckpointWrapper`).
+
+    extension: str, optional
+        Extension of the preprocessed files (see :any:`bob.pipelines.CheckpointWrapper`).
+
+    load_func: None, optional
+        Function that loads preprocessed data.
+        The default is :any:`bob.bio.base.preprocessor.Preprocessor.read_data`
+
+    save_func: None, optional
+        Function that saves preprocessed data.
+        The default is :any:`bob.bio.base.preprocessor.Preprocessor.write_data`
+
+    transform_extra_arguments: [tuple]
+        Same behavior as in :any:`bob.pipelines.wrappers.transform_extra_arguments`
+    """
+    transformer = PreprocessorTransformer(preprocessor)
+    return mario.wrap(
+        ["sample", "checkpoint"],
+        transformer,
+        load_func=load_func or preprocessor.read_data,
+        save_func=save_func or preprocessor.write_data,
+        features_dir=features_dir,
+        transform_extra_arguments=transform_extra_arguments,
+        extension=extension,
+    )
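The two preprocessor wrappers differ only in persistence. A short sketch, reusing FakePreprocesor from the tests; the exact checkpoint file layout is an assumption based on bob.pipelines' CheckpointWrapper:

    # Sample-only: results stay in memory, nothing is written to disk.
    in_memory = wrap_sample_preprocessor(FakePreprocesor())

    # Checkpointed: each processed sample is saved under features_dir
    # (per sample key, with the configured extension) and reloaded with
    # read_data on the next run instead of being recomputed.
    cached = wrap_checkpoint_preprocessor(FakePreprocesor(), features_dir=dir_name)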
-def wrap_extractor(
-    extractor_transformer,
-    fit_extra_arguments=None,
-    transform_extra_arguments=None,
-    features_dir=None,
-    model_path=None,
-):
-    """
-    Wraps :any:`bob.bio.base.transformers.ExtractorTransformer` with
-    :any:`bob.pipelines.wrappers.CheckpointWrapper` and :any:`bob.pipelines.wrappers.SampleWrapper`
-
-    Parameters
-    ----------
-
-    extractor_transformer: :any:`bob.bio.base.transformers.ExtractorTransformer`
-        Instance of :any:`bob.bio.base.transformers.ExtractorTransformer` to be wrapped
-
-    features_dir: str
-        Features directory to be checkpointed
-
-    model_path: str
-        Path to checkpoint the model
-
-    fit_extra_arguments: [tuple]
-        Same behavior as in Check :any:`bob.pipelines.wrappers.fit_extra_arguments`
-
-    transform_extra_arguments: [tuple]
-        Same behavior as in Check :any:`bob.pipelines.wrappers.transform_extra_arguments`
-    """
-    if not isinstance(extractor_transformer, ExtractorTransformer):
-        raise ValueError(
-            f"Expected an instance of ExtractorTransformer, not {extractor_transformer}"
-        )
-
-    return mario.wrap(
-        ["sample", "checkpoint"],
-        extractor_transformer,
-        load_func=extractor_transformer.callable.read_feature,
-        save_func=extractor_transformer.callable.write_feature,
-        model_path=model_path,
-        features_dir=features_dir,
-        transform_extra_arguments=transform_extra_arguments,
-        fit_extra_arguments=fit_extra_arguments,
-    )
+def _prepare_extractor_sample_args(
+    extractor, transform_extra_arguments, fit_extra_arguments
+):
+    if transform_extra_arguments is None and is_argument_available(
+        "metadata", extractor.__call__
+    ):
+        transform_extra_arguments = (("metadata", "metadata"),)
+
+    if (
+        fit_extra_arguments is None
+        and extractor.requires_training
+        and extractor.split_training_data_by_client
+    ):
+        fit_extra_arguments = (("y", "subject"),)
+
+    return transform_extra_arguments, fit_extra_arguments
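This helper supplies the "default options" mentioned in the commit message: the extra arguments are now inferred from the legacy object itself instead of being passed in by hand. An illustration with a hypothetical extractor class:

    # If the legacy extractor's __call__ accepts a ``metadata`` argument,
    # each sample's metadata is forwarded to it automatically.
    class MetadataExtractor(Extractor):  # hypothetical example class
        def __call__(self, data, metadata=None):
            return data.flatten()

    args = _prepare_extractor_sample_args(MetadataExtractor(), None, None)
    # -> transform_extra_arguments == (("metadata", "metadata"),)
    # fit_extra_arguments stays None unless the extractor requires training
    # AND splits its training data by client.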
+def wrap_sample_extractor(
+    extractor,
+    fit_extra_arguments=None,
+    transform_extra_arguments=None,
+    model_path=None,
+    **kwargs,
+):
+    """
+    Wraps :any:`bob.bio.base.extractor.Extractor` with
+    :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    Parameters
+    ----------
+
+    extractor: :any:`bob.bio.base.extractor.Extractor`
+        Instance of :any:`bob.bio.base.extractor.Extractor` to be wrapped
+
+    transform_extra_arguments: [tuple], optional
+        Same behavior as in :any:`bob.pipelines.wrappers.transform_extra_arguments`
+
+    model_path: str
+        Path to `extractor_file` in :any:`bob.bio.base.extractor.Extractor`
+    """
+    extractor_file = (
+        os.path.join(model_path, "Extractor.hdf5") if model_path is not None else None
+    )
+    transformer = ExtractorTransformer(extractor, model_path=extractor_file)
+    transform_extra_arguments, fit_extra_arguments = _prepare_extractor_sample_args(
+        extractor, transform_extra_arguments, fit_extra_arguments
+    )
+    return mario.wrap(
+        ["sample"],
+        transformer,
+        transform_extra_arguments=transform_extra_arguments,
+        fit_extra_arguments=fit_extra_arguments,
+        **kwargs,
+    )
+
+
+def wrap_checkpoint_extractor(
+    extractor,
+    features_dir=None,
+    fit_extra_arguments=None,
+    transform_extra_arguments=None,
+    load_func=None,
+    save_func=None,
+    extension=".hdf5",
+    model_path=None,
+    **kwargs,
+):
+    """
+    Wraps :any:`bob.bio.base.extractor.Extractor` with
+    :any:`bob.pipelines.wrappers.CheckpointWrapper` and :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    Parameters
+    ----------
+
+    extractor: :any:`bob.bio.base.extractor.Extractor`
+        Instance of :any:`bob.bio.base.extractor.Extractor` to be wrapped
+
+    features_dir: str
+        Features directory to be checkpointed (see :any:`bob.pipelines.CheckpointWrapper`).
+
+    extension: str, optional
+        Extension of the extracted files (see :any:`bob.pipelines.CheckpointWrapper`).
+
+    load_func: None, optional
+        Function that loads extracted features.
+        The default is :any:`bob.bio.base.extractor.Extractor.read_feature`
+
+    save_func: None, optional
+        Function that saves extracted features.
+        The default is :any:`bob.bio.base.extractor.Extractor.write_feature`
+
+    fit_extra_arguments: [tuple]
+        Same behavior as in :any:`bob.pipelines.wrappers.fit_extra_arguments`
+
+    transform_extra_arguments: [tuple], optional
+        Same behavior as in :any:`bob.pipelines.wrappers.transform_extra_arguments`
+
+    model_path: str
+        See :any:`bob.bio.base.transformers.ExtractorTransformer`.
+    """
+    extractor_file = (
+        os.path.join(model_path, "Extractor.hdf5") if model_path is not None else None
+    )
+
+    model_file = (
+        os.path.join(model_path, "Extractor.pkl") if model_path is not None else None
+    )
+    transformer = ExtractorTransformer(extractor, model_path=extractor_file)
+    transform_extra_arguments, fit_extra_arguments = _prepare_extractor_sample_args(
+        extractor, transform_extra_arguments, fit_extra_arguments
+    )
+    return mario.wrap(
+        ["sample", "checkpoint"],
+        transformer,
+        load_func=load_func or extractor.read_feature,
+        save_func=save_func or extractor.write_feature,
+        model_path=model_file,
+        features_dir=features_dir,
+        transform_extra_arguments=transform_extra_arguments,
+        fit_extra_arguments=fit_extra_arguments,
+        **kwargs,
+    )
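A usage sketch based on the updated tests; model_path is passed as a keyword here, since the wrappers derive Extractor.hdf5 (the legacy model) and Extractor.pkl (the checkpointed transformer state) inside that directory:

    sample_extractor = wrap_sample_extractor(FakeExtractor(), model_path=dir_name)
    checkpoint_extractor = wrap_checkpoint_extractor(
        FakeExtractor(), features_dir=dir_name, model_path=dir_name
    )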
-def wrap_algorithm(
-    algorithm_transformer,
-    fit_extra_arguments=None,
-    transform_extra_arguments=None,
-    features_dir=None,
-    model_path=None,
-):
-    """
-    Wraps :any:`bob.bio.base.transformers.AlgorithmTransformer` with
-    :any:`bob.pipelines.wrappers.CheckpointWrapper` and :any:`bob.pipelines.wrappers.SampleWrapper`
-
-    Parameters
-    ----------
-
-    algorithm_transformer: :any:`bob.bio.base.transformers.AlgorithmTransformer`
-        Instance of :any:`bob.bio.base.transformers.AlgorithmTransformer` to be wrapped
-
-    features_dir: str
-        Features directory to be checkpointed
-
-    model_path: str
-        Path to checkpoint the model
-
-    fit_extra_arguments: [tuple]
-        Same behavior as in Check :any:`bob.pipelines.wrappers.fit_extra_arguments`
-
-    transform_extra_arguments: [tuple]
-        Same behavior as in Check :any:`bob.pipelines.wrappers.transform_extra_arguments`
-    """
-    if not isinstance(algorithm_transformer, AlgorithmTransformer):
-        raise ValueError(
-            f"Expected an instance of AlgorithmTransformer, not {algorithm_transformer}"
-        )
-
-    return mario.wrap(
-        ["sample", "checkpoint"],
-        algorithm_transformer,
-        load_func=algorithm_transformer.callable.read_feature,
-        save_func=algorithm_transformer.callable.write_feature,
-        model_path=model_path,
-        features_dir=features_dir,
-        transform_extra_arguments=transform_extra_arguments,
-        fit_extra_arguments=fit_extra_arguments,
-    )
+def _prepare_algorithm_sample_args(
+    algorithm, transform_extra_arguments, fit_extra_arguments
+):
+    if transform_extra_arguments is None and is_argument_available(
+        "metadata", algorithm.project
+    ):
+        transform_extra_arguments = (("metadata", "metadata"),)
+
+    if (
+        fit_extra_arguments is None
+        and algorithm.requires_projector_training
+        and algorithm.split_training_features_by_client
+    ):
+        fit_extra_arguments = (("y", "subject"),)
+
+    return transform_extra_arguments, fit_extra_arguments
+def wrap_sample_algorithm(
+    algorithm,
+    model_path=None,
+    fit_extra_arguments=None,
+    transform_extra_arguments=None,
+    **kwargs,
+):
+    """
+    Wraps :any:`bob.bio.base.algorithm.Algorithm` with
+    :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    Parameters
+    ----------
+
+    algorithm: :any:`bob.bio.base.algorithm.Algorithm`
+        Instance of :any:`bob.bio.base.algorithm.Algorithm` to be wrapped
+
+    model_path: str
+        Path to `projector_file` in :any:`bob.bio.base.algorithm.Algorithm`
+
+    fit_extra_arguments: [tuple]
+        Same behavior as in :any:`bob.pipelines.wrappers.fit_extra_arguments`
+
+    transform_extra_arguments: [tuple]
+        Same behavior as in :any:`bob.pipelines.wrappers.transform_extra_arguments`
+    """
+    projector_file = (
+        os.path.join(model_path, "Projector.hdf5") if model_path is not None else None
+    )
+    transformer = AlgorithmTransformer(algorithm, projector_file=projector_file)
+    transform_extra_arguments, fit_extra_arguments = _prepare_algorithm_sample_args(
+        algorithm, transform_extra_arguments, fit_extra_arguments
+    )
+    return mario.wrap(
+        ["sample"],
+        transformer,
+        transform_extra_arguments=transform_extra_arguments,
+        fit_extra_arguments=fit_extra_arguments,
+    )
+
+
+def wrap_checkpoint_algorithm(
+    algorithm,
+    model_path=None,
+    features_dir=None,
+    extension=".hdf5",
+    fit_extra_arguments=None,
+    transform_extra_arguments=None,
+    load_func=None,
+    save_func=None,
+    **kwargs,
+):
+    """
+    Wraps :any:`bob.bio.base.algorithm.Algorithm` with
+    :any:`bob.pipelines.wrappers.CheckpointWrapper` and :any:`bob.pipelines.wrappers.SampleWrapper`
+
+    Parameters
+    ----------
+
+    algorithm: :any:`bob.bio.base.algorithm.Algorithm`
+        Instance of :any:`bob.bio.base.algorithm.Algorithm` to be wrapped
+
+    features_dir: str
+        Features directory to be checkpointed (see :any:`bob.pipelines.CheckpointWrapper`).
+
+    model_path: str
+        Path to checkpoint the model
+
+    extension: str, optional
+        Extension of the projected files (see :any:`bob.pipelines.CheckpointWrapper`).
+
+    fit_extra_arguments: [tuple]
+        Same behavior as in :any:`bob.pipelines.wrappers.fit_extra_arguments`
+
+    transform_extra_arguments: [tuple]
+        Same behavior as in :any:`bob.pipelines.wrappers.transform_extra_arguments`
+
+    load_func: None, optional
+        Function that loads projected features.
+        The default is :any:`bob.bio.base.algorithm.Algorithm.read_feature`
+
+    save_func: None, optional
+        Function that saves projected features.
+        The default is :any:`bob.bio.base.algorithm.Algorithm.write_feature`
+    """
+    projector_file = (
+        os.path.join(model_path, "Projector.hdf5") if model_path is not None else None
+    )
+
+    model_file = (
+        os.path.join(model_path, "Projector.pkl") if model_path is not None else None
+    )
+    transformer = AlgorithmTransformer(algorithm, projector_file=projector_file)
+    transform_extra_arguments, fit_extra_arguments = _prepare_algorithm_sample_args(
+        algorithm, transform_extra_arguments, fit_extra_arguments
+    )
+    return mario.wrap(
+        ["sample", "checkpoint"],
+        transformer,
+        load_func=load_func or algorithm.read_feature,
+        save_func=save_func or algorithm.write_feature,
+        model_path=model_file,
+        features_dir=features_dir,
+        transform_extra_arguments=transform_extra_arguments,
+        fit_extra_arguments=fit_extra_arguments,
...
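A usage sketch based on the updated tests; the file names in the comments follow from the wrapper bodies above:

    # Checkpointed: the legacy projector is stored as Projector.hdf5 and the
    # transformer state as Projector.pkl under model_path; projected features
    # are written under features_dir.
    algorithm = wrap_checkpoint_algorithm(
        FakeAlgorithm(), features_dir=dir_name, model_path=dir_name
    )

    # Sample-only variant: no feature checkpointing.
    algorithm_in_memory = wrap_sample_algorithm(FakeAlgorithm(), model_path=dir_name)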